diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3510f3d74f2d7bf3a726ecf4659027ba7e986e87
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,135 @@
+.vscode
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+*.png
+*.jpeg
+*.jpg
+*.pt
+*.gif
+*.pth
+*.dat
+*.zip
+*.so
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# project
+results/
+experiments/
+tb_logger/
+build/
+
+run.sh
+*debug*
+*_old*
+
+*.swp
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..ebf8d30fd4477a396fc6f663a5152b07049cd720
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "basicsr/archs/gmflow"]
+ path = basicsr/archs/gmflow
+ url = https://github.com/haofeixu/gmflow
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ff76d0f6b7c513e7d72a2a732d95762b007a4efe
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2024 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
diff --git a/README.md b/README.md
index 41495cda6077b71980b7a81ba7990c0638bd6bcd..37b1852c94e343d25f795b01ce791da06994e102 100644
--- a/README.md
+++ b/README.md
@@ -4,8 +4,8 @@ emoji: 🏢
colorFrom: blue
colorTo: red
sdk: gradio
-sdk_version: 5.29.0
-app_file: app.py
+sdk_version: 5.4.0
+app_file: hugging_face/app.py
pinned: false
license: other
short_description: Official demo of KEEP (ECCV'24) for face video SR
diff --git a/basicsr/VERSION b/basicsr/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..1892b926767774e9ba91f1e584fa71b4c56abb69
--- /dev/null
+++ b/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/basicsr/__init__.py b/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ffcccd7fc0f33b59d99d73d0436d60e561b0fc
--- /dev/null
+++ b/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ce844726fa3922867f0580993bf0a3fa41c895
--- /dev/null
+++ b/basicsr/archs/__init__.py
@@ -0,0 +1,27 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0]
+ for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(
+ f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
diff --git a/basicsr/archs/arcface_arch.py b/basicsr/archs/arcface_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5afb7bd2b359e0c2b7efdf628ab10b63964d87
--- /dev/null
+++ b/basicsr/archs/arcface_arch.py
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+ """A simple wrapper for 3x3 convolution with padding.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ outplanes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ """
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Basic residual block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class IRBlock(nn.Module):
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+ super(IRBlock, self).__init__()
+ self.bn0 = nn.BatchNorm2d(inplanes)
+ self.conv1 = conv3x3(inplanes, inplanes)
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.prelu = nn.PReLU()
+ self.conv2 = conv3x3(inplanes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.use_se = use_se
+ if self.use_se:
+ self.se = SEBlock(planes)
+
+ def forward(self, x):
+ residual = x
+ out = self.bn0(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.use_se:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.prelu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 4 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class SEBlock(nn.Module):
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+ Args:
+ channel (int): Channel number of inputs.
+ reduction (int): Channel reduction ration. Default: 16.
+ """
+
+ def __init__(self, channel, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+ """ArcFace with ResNet architectures.
+
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+ Args:
+ block (str): Block used in the ArcFace architecture.
+ layers (tuple(int)): Block numbers in each layer.
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+ """
+
+ def __init__(self, block, layers, use_se=True):
+ if block == 'IRBlock':
+ block = IRBlock
+ self.inplanes = 64
+ self.use_se = use_se
+ super(ResNetArcFace, self).__init__()
+
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.prelu = nn.PReLU()
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.bn4 = nn.BatchNorm2d(512)
+ self.dropout = nn.Dropout()
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
+ self.bn5 = nn.BatchNorm1d(512)
+
+ # initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+ self.inplanes = planes
+ for _ in range(1, num_blocks):
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn4(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc5(x)
+ x = self.bn5(x)
+
+ return x
\ No newline at end of file
diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad45ab34e901c47fb539152fca714a3795b0de2
--- /dev/null
+++ b/basicsr/archs/arch_util.py
@@ -0,0 +1,318 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+ align_corners=False. Here, we use the True as default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+ from the preceding features, this DCNv2Pack takes another different
+ features to generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
\ No newline at end of file
diff --git a/basicsr/archs/correlation.py b/basicsr/archs/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d32a583ecbab9314870eda3c8b6f59c3a08f7f4
--- /dev/null
+++ b/basicsr/archs/correlation.py
@@ -0,0 +1,426 @@
+import torch
+
+import cupy
+import re
+
+
+class Stream:
+ ptr = torch.cuda.current_stream().cuda_stream
+# end
+
+
+kernel_Correlation_rearrange = '''
+ extern "C" __global__ void kernel_Correlation_rearrange(
+ const int n,
+ const float* input,
+ float* output
+ ) {
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
+
+ if (intIndex >= n) {
+ return;
+ }
+
+ int intSample = blockIdx.z;
+ int intChannel = blockIdx.y;
+
+ float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
+
+ __syncthreads();
+
+ int intPaddedY = (intIndex / SIZE_3(input)) + 4;
+ int intPaddedX = (intIndex % SIZE_3(input)) + 4;
+ int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
+
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
+ }
+'''
+
+kernel_Correlation_updateOutput = '''
+ extern "C" __global__ void kernel_Correlation_updateOutput(
+ const int n,
+ const float* rbot0,
+ const float* rbot1,
+ float* top
+ ) {
+ extern __shared__ char patch_data_char[];
+
+ float *patch_data = (float *)patch_data_char;
+
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
+ int x1 = blockIdx.x + 4;
+ int y1 = blockIdx.y + 4;
+ int item = blockIdx.z;
+ int ch_off = threadIdx.x;
+
+ // Load 3D patch into shared shared memory
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
+ int idxPatchData = ji_off + ch;
+ patch_data[idxPatchData] = rbot0[idx1];
+ }
+ }
+ }
+
+ __syncthreads();
+
+ __shared__ float sum[32];
+
+ // Compute correlation
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
+ sum[ch_off] = 0;
+
+ int s2o = top_channel % 9 - 4;
+ int s2p = top_channel / 9 - 4;
+
+ for (int j = 0; j < 1; j++) { // HEIGHT
+ for (int i = 0; i < 1; i++) { // WIDTH
+ int ji_off = (j + i) * SIZE_3(rbot0);
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
+ int x2 = x1 + s2o;
+ int y2 = y1 + s2p;
+
+ int idxPatchData = ji_off + ch;
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
+
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if (ch_off == 0) {
+ float total_sum = 0;
+ for (int idx = 0; idx < 32; idx++) {
+ total_sum += sum[idx];
+ }
+ const int sumelems = SIZE_3(rbot0);
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
+ }
+ }
+ }
+'''
+
+kernel_Correlation_updateGradFirst = '''
+ #define ROUND_OFF 50000
+
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradFirst); // channels
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
+
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+ int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
+
+ // Same here:
+ int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
+ int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
+
+ float sum = 0;
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ // Get rbot1 data:
+ int s2o = o;
+ int s2p = p;
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
+
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot1tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradFirst);
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
+ } }
+'''
+
+kernel_Correlation_updateGradSecond = '''
+ #define ROUND_OFF 50000
+
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
+ const int n,
+ const int intSample,
+ const float* rbot0,
+ const float* rbot1,
+ const float* gradOutput,
+ float* gradFirst,
+ float* gradSecond
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ int n = intIndex % SIZE_1(gradSecond); // channels
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
+
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
+ // We use a large offset, for the inner part not to become negative.
+ const int round_off = ROUND_OFF;
+ const int round_off_s1 = round_off;
+
+ float sum = 0;
+ for (int p = -4; p <= 4; p++) {
+ for (int o = -4; o <= 4; o++) {
+ int s2o = o;
+ int s2p = p;
+
+ //Get X,Y ranges and clamp
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
+ int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+ int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
+
+ // Same here:
+ int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
+ int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
+
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
+ xmin = max(0,xmin);
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
+
+ ymin = max(0,ymin);
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
+
+ // Get rbot0 data:
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
+
+ // Index offset for gradOutput in following loops:
+ int op = (p+4) * 9 + (o+4); // index[o,p]
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
+
+ for (int y = ymin; y <= ymax; y++) {
+ for (int x = xmin; x <= xmax; x++) {
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
+ sum += gradOutput[idxgradOutput] * bot0tmp;
+ }
+ }
+ }
+ }
+ }
+ const int sumelems = SIZE_1(gradSecond);
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
+ } }
+'''
+
+
+def cupy_kernel(strFunction, objectVariables):
+ strKernel = globals()[strFunction]
+
+ while True:
+ objectMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
+
+ if objectMatch is None:
+ break
+ # end
+
+ intArg = int(objectMatch.group(2))
+
+ strTensor = objectMatch.group(4)
+ intSizes = objectVariables[strTensor].size()
+
+ strKernel = strKernel.replace(
+ objectMatch.group(), str(intSizes[intArg]))
+ # end
+
+ while True:
+ objectMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
+
+ if objectMatch is None:
+ break
+ # end
+
+ intArgs = int(objectMatch.group(2))
+ strArgs = objectMatch.group(4).split(',')
+
+ strTensor = strArgs[0]
+ intStrides = objectVariables[strTensor].stride()
+ strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip(
+ ) + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs)]
+
+ strKernel = strKernel.replace(objectMatch.group(
+ 0), strTensor + '[' + str.join('+', strIndex) + ']')
+ # end
+
+ return strKernel
+# end
+
+# @cupy.util.memoize(for_each_device=True)
+
+
+@cupy.memoize(for_each_device=True)
+def cupy_launch(strFunction, strKernel):
+ return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
+# end
+
+
+class _FunctionCorrelation(torch.autograd.Function):
+ @staticmethod
+ def forward(self, first, second):
+ rbot0 = first.new_zeros([first.size(0), first.size(
+ 2) + 8, first.size(3) + 8, first.size(1)])
+ rbot1 = first.new_zeros([first.size(0), first.size(
+ 2) + 8, first.size(3) + 8, first.size(1)])
+
+ self.save_for_backward(first, second, rbot0, rbot1)
+
+ assert(first.is_contiguous() == True)
+ assert(second.is_contiguous() == True)
+
+ output = first.new_zeros(
+ [first.size(0), 81, first.size(2), first.size(3)])
+
+ if first.is_cuda == True:
+ n = first.size(2) * first.size(3)
+ cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
+ 'input': first,
+ 'output': rbot0
+ }))(
+ grid=tuple([int((n + 16 - 1) / 16),
+ first.size(1), first.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, first.data_ptr(), rbot0.data_ptr()],
+ stream=Stream
+ )
+
+ n = second.size(2) * second.size(3)
+ cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
+ 'input': second,
+ 'output': rbot1
+ }))(
+ grid=tuple([int((n + 16 - 1) / 16),
+ second.size(1), second.size(0)]),
+ block=tuple([16, 1, 1]),
+ args=[n, second.data_ptr(), rbot1.data_ptr()],
+ stream=Stream
+ )
+
+ n = output.size(1) * output.size(2) * output.size(3)
+ cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
+ 'rbot0': rbot0,
+ 'rbot1': rbot1,
+ 'top': output
+ }))(
+ grid=tuple([output.size(3), output.size(2), output.size(0)]),
+ block=tuple([32, 1, 1]),
+ shared_mem=first.size(1) * 4,
+ args=[n, rbot0.data_ptr(), rbot1.data_ptr(),
+ output.data_ptr()],
+ stream=Stream
+ )
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ # end
+
+ return output
+ # end
+
+ @staticmethod
+ def backward(self, gradOutput):
+ first, second, rbot0, rbot1 = self.saved_tensors
+
+ assert(gradOutput.is_contiguous() == True)
+
+ gradFirst = first.new_zeros([first.size(0), first.size(1), first.size(
+ 2), first.size(3)]) if self.needs_input_grad[0] == True else None
+ gradSecond = first.new_zeros([first.size(0), first.size(1), first.size(
+ 2), first.size(3)]) if self.needs_input_grad[1] == True else None
+
+ if first.is_cuda == True:
+ if gradFirst is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
+ 'rbot0': rbot0,
+ 'rbot1': rbot1,
+ 'gradOutput': gradOutput,
+ 'gradFirst': gradFirst,
+ 'gradSecond': None
+ }))(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(
+ ), gradOutput.data_ptr(), gradFirst.data_ptr(), None],
+ stream=Stream
+ )
+ # end
+ # end
+
+ if gradSecond is not None:
+ for intSample in range(first.size(0)):
+ n = first.size(1) * first.size(2) * first.size(3)
+ cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
+ 'rbot0': rbot0,
+ 'rbot1': rbot1,
+ 'gradOutput': gradOutput,
+ 'gradFirst': None,
+ 'gradSecond': gradSecond
+ }))(
+ grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(
+ ), gradOutput.data_ptr(), None, gradSecond.data_ptr()],
+ stream=Stream
+ )
+ # end
+ # end
+
+ elif first.is_cuda == False:
+ raise NotImplementedError()
+
+ # end
+
+ return gradFirst, gradSecond
+ # end
+# end
+
+
+def FunctionCorrelation(tensorFirst, tensorSecond):
+ return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
+# end
+
+
+class ModuleCorrelation(torch.nn.Module):
+ def __init__(self):
+ super(ModuleCorrelation, self).__init__()
+ # end
+
+ def forward(self, tensorFirst, tensorSecond):
+ return _FunctionCorrelation.apply(tensorFirst, tensorSecond)
+ # end
+# end
diff --git a/basicsr/archs/gmflow/.gitignore b/basicsr/archs/gmflow/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609
--- /dev/null
+++ b/basicsr/archs/gmflow/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/basicsr/archs/gmflow/LICENSE b/basicsr/archs/gmflow/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..8ba17c78e378819527e65ef7d1a767f035a792ac
--- /dev/null
+++ b/basicsr/archs/gmflow/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022, Haofei Xu
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/basicsr/archs/gmflow/README.md b/basicsr/archs/gmflow/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..17449970f8861dec7fef8d8835fc7e92abeb2332
--- /dev/null
+++ b/basicsr/archs/gmflow/README.md
@@ -0,0 +1,239 @@
+# GMFlow
+
+
+Official PyTorch implementation of paper:
+
+[**GMFlow: Learning Optical Flow via Global Matching**](https://arxiv.org/abs/2111.13680), **CVPR 2022, Oral**
+
+Authors: [Haofei Xu](https://haofeixu.github.io/), [Jing Zhang](https://scholar.google.com.hk/citations?user=9jH5v74AAAAJ), [Jianfei Cai](https://jianfei-cai.github.io/), [Hamid Rezatofighi](https://scholar.google.com/citations?user=VxAuxMwAAAAJ), [Dacheng Tao](https://scholar.google.com/citations?user=RwlJNLcAAAAJ)
+
+
+**11/15/2022 Update: Check out our new work: [Unifying Flow, Stereo and Depth Estimation](https://haofeixu.github.io/unimatch/) and code: [unimatch](https://github.com/autonomousvision/unimatch) for extending GMFlow to stereo and depth tasks. [More pretrained GMFlow models](https://github.com/autonomousvision/unimatch/blob/master/MODEL_ZOO.md) with different speed-accuracy trade-offs are also released. Check out our [Colab](https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing) and [HuggingFace](https://huggingface.co/spaces/haofeixu/unimatch) demo to play with GMFlow in your browser!**
+
+
+
+**A [video introduction](https://www.bilibili.com/video/BV18A4y1R7PL) (in Chinese) of GMFlow is available at bilibili!**
+
+
+
+https://user-images.githubusercontent.com/19343475/174446408-520b8a6c-9714-4ff3-978c-98e23ab29c1f.mp4
+
+
+
+
+
+We streamline the optical flow estimation pipeline by reformulating optical flow as a **global matching** problem.
+
+
+
+
+

+
+
+
+
+
+## Highlights
+
+- **Flexible & Modular design**
+
+ We decompose the end-to-end optical flow framework into five components:
+
+ feature extraction, feature enhancement, feature matching, flow propagation and flow refinement.
+
+ One can easily construct a customized optical flow model by combining different components.
+
+- **High accuracy**
+
+ With only one refinement, GMFlow outperforms 31-refinements RAFT on the challenging Sintel benchmark.
+
+- **High efficiency**
+
+ A basic GMFlow model (without refinement) runs at 57ms (V100) or 26ms (A100) for Sintel data (436x1024).
+
+ GMFlow gains more speedup than RAFT on high-end GPUs (e.g., A100) since GMFlow doesn't require a large number of sequential computation.
+
+ GMFlow also simplifies backward flow computation without requiring to forward the network twice. The bidirectional flow can be used for occlusion detection with forward-backward consistency check.
+
+ 
+
+
+
+
+## Installation
+
+Our code is based on pytorch 1.9.0, CUDA 10.2 and python 3.8. Higher version pytorch should also work well.
+
+We recommend using [conda](https://www.anaconda.com/distribution/) for installation:
+
+```
+conda env create -f environment.yml
+conda activate gmflow
+```
+
+## Demos
+
+All pretrained models can be downloaded from [google drive](https://drive.google.com/file/d/1d5C5cgHIxWGsFR1vYs5XrQbbUiZl9TX2/view?usp=sharing).
+
+
+
+You can run a trained model on a sequence of images and visualize the results:
+
+```
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/sintel_market_1 \
+--output_path output/gmflow-norefine-sintel_market_1 \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+```
+
+You can also predict bidirectional flow with `--pred_bidir_flow` enabled and use `--fwd_bwd_consistency_check` for forward-backward consistency check. More examples can be found in [scripts/demo.sh](scripts/demo.sh).
+
+
+
+## Datasets
+
+The datasets used to train and evaluate GMFlow are as follows:
+
+* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
+* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
+* [Sintel](http://sintel.is.tue.mpg.de/)
+* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
+* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/)
+
+By default the dataloader [datasets.py](data/datasets.py) assumes the datasets are located in folder `datasets` and are organized as follows:
+
+```
+datasets
+├── FlyingChairs_release
+│ └── data
+├── FlyingThings3D
+│ ├── frames_cleanpass
+│ ├── frames_finalpass
+│ └── optical_flow
+├── HD1K
+│ ├── hd1k_challenge
+│ ├── hd1k_flow_gt
+│ ├── hd1k_flow_uncertainty
+│ └── hd1k_input
+├── KITTI
+│ ├── testing
+│ └── training
+├── Sintel
+│ ├── test
+│ └── training
+```
+
+It is recommended to symlink your dataset root to `datasets`:
+
+```shell
+ln -s $YOUR_DATASET_ROOT datasets
+```
+
+Otherwise, you may need to change the corresponding paths in [datasets.py](data/datasets.py).
+
+
+
+## Evaluation
+
+You can evaluate a trained GMFlow model by running:
+
+```
+CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset things sintel --resume pretrained/gmflow_things-e9887eda.pth
+```
+
+More evaluation scripts can be found in [scripts/evaluate.sh](scripts/evaluate.sh).
+
+
+
+For submission to Sintel and KITTI online test sets, you can run [scripts/submission.sh](scripts/submission.sh).
+
+
+
+## Training
+
+All training scripts on FlyingChairs, FlyingThings3D, Sintel and KITTI datasets can be found in [scripts/train_gmflow.sh](scripts/train_gmflow.sh) and [scripts/train_gmflow_with_refine.sh](scripts/train_gmflow_with_refine.sh).
+
+Note that the basic GMFlow model (without refinement) can be trained on 4x 16GB V100 GPUs. For training GMFlow with refinement, 8x 16GB V100 or 4x 32GB V100 or 4x 40GB A100 GPUs are required by default. You may need to tune the batch size and training iterations according to your hardware.
+
+
+
+We support using tensorboard to monitor and visualize the training process. You can first start a tensorboard session with
+
+```shell
+tensorboard --logdir checkpoints
+```
+
+and then access [http://localhost:6006](http://localhost:6006) in your browser.
+
+
+
+## Citation
+
+If you find our work useful in your research, please consider citing our paper:
+
+```
+@inproceedings{xu2022gmflow,
+ title={GMFlow: Learning Optical Flow via Global Matching},
+ author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Tao, Dacheng},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={8121-8130},
+ year={2022}
+}
+```
+
+
+
+## Acknowledgements
+
+This project would not have been possible without relying on some awesome repos : [RAFT](https://github.com/princeton-vl/RAFT), [LoFTR](https://github.com/zju3dv/LoFTR), [DETR](https://github.com/facebookresearch/detr), [Swin](https://github.com/microsoft/Swin-Transformer), [mmdetection](https://github.com/open-mmlab/mmdetection) and [Detectron2](https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py). We thank the original authors for their excellent work.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/basicsr/archs/gmflow/data/__init__.py b/basicsr/archs/gmflow/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..895b3281e7af148f74ecbc13a16d329863aeb49a
--- /dev/null
+++ b/basicsr/archs/gmflow/data/__init__.py
@@ -0,0 +1,7 @@
+from .datasets import build_train_dataset
+from .datasets import (FlyingChairs,
+ FlyingThings3D,
+ MpiSintel,
+ KITTI,
+ HD1K,
+ )
diff --git a/basicsr/archs/gmflow/data/chairs_split.txt b/basicsr/archs/gmflow/data/chairs_split.txt
new file mode 100755
index 0000000000000000000000000000000000000000..6ae8f0b72a22fc061552604c94664e3a0287914e
--- /dev/null
+++ b/basicsr/archs/gmflow/data/chairs_split.txt
@@ -0,0 +1,22872 @@
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+2
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+1
+1
+1
+1
+1
+1
+1
+2
+1
+1
+1
+1
+1
\ No newline at end of file
diff --git a/basicsr/archs/gmflow/data/datasets.py b/basicsr/archs/gmflow/data/datasets.py
new file mode 100755
index 0000000000000000000000000000000000000000..6e2f1584f9c013fb0e4d4ac331d856da363e0c9b
--- /dev/null
+++ b/basicsr/archs/gmflow/data/datasets.py
@@ -0,0 +1,312 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+
+import os
+import random
+from glob import glob
+import os.path as osp
+
+from utils import frame_utils
+from data.transforms import FlowAugmentor, SparseFlowAugmentor
+
+
+class FlowDataset(data.Dataset):
+ def __init__(self, aug_params=None, sparse=False,
+ load_occlusion=False,
+ ):
+ self.augmentor = None
+ self.sparse = sparse
+
+ if aug_params is not None:
+ if sparse:
+ self.augmentor = SparseFlowAugmentor(**aug_params)
+ else:
+ self.augmentor = FlowAugmentor(**aug_params)
+
+ self.is_test = False
+ self.init_seed = False
+ self.flow_list = []
+ self.image_list = []
+ self.extra_info = []
+
+ self.load_occlusion = load_occlusion
+ self.occ_list = []
+
+ def __getitem__(self, index):
+
+ if self.is_test:
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
+
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+
+ return img1, img2, self.extra_info[index]
+
+ if not self.init_seed:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ torch.manual_seed(worker_info.id)
+ np.random.seed(worker_info.id)
+ random.seed(worker_info.id)
+ self.init_seed = True
+
+ index = index % len(self.image_list)
+ valid = None
+
+ if self.sparse:
+ flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) # [H, W, 2], [H, W]
+ else:
+ flow = frame_utils.read_gen(self.flow_list[index])
+
+ if self.load_occlusion:
+ occlusion = frame_utils.read_gen(self.occ_list[index]) # [H, W], 0 or 255 (occluded)
+
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+
+ flow = np.array(flow).astype(np.float32)
+ img1 = np.array(img1).astype(np.uint8)
+ img2 = np.array(img2).astype(np.uint8)
+
+ if self.load_occlusion:
+ occlusion = np.array(occlusion).astype(np.float32)
+
+ # grayscale images
+ if len(img1.shape) == 2:
+ img1 = np.tile(img1[..., None], (1, 1, 3))
+ img2 = np.tile(img2[..., None], (1, 1, 3))
+ else:
+ img1 = img1[..., :3]
+ img2 = img2[..., :3]
+
+ if self.augmentor is not None:
+ if self.sparse:
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+ else:
+ if self.load_occlusion:
+ img1, img2, flow, occlusion = self.augmentor(img1, img2, flow, occlusion=occlusion)
+ else:
+ img1, img2, flow = self.augmentor(img1, img2, flow)
+
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+ if self.load_occlusion:
+ occlusion = torch.from_numpy(occlusion) # [H, W]
+
+ if valid is not None:
+ valid = torch.from_numpy(valid)
+ else:
+ valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+ # mask out occluded pixels
+ if self.load_occlusion:
+ # non-occlusion: 0, occlusion: 255
+ noc_valid = 1 - occlusion / 255. # 0 or 1
+
+ return img1, img2, flow, valid.float(), noc_valid.float()
+
+ return img1, img2, flow, valid.float()
+
+ def __rmul__(self, v):
+ self.flow_list = v * self.flow_list
+ self.image_list = v * self.image_list
+
+ return self
+
+ def __len__(self):
+ return len(self.image_list)
+
+
+class MpiSintel(FlowDataset):
+ def __init__(self, aug_params=None, split='training',
+ root='datasets/Sintel',
+ dstype='clean',
+ load_occlusion=False,
+ ):
+ super(MpiSintel, self).__init__(aug_params,
+ load_occlusion=load_occlusion,
+ )
+
+ flow_root = osp.join(root, split, 'flow')
+ image_root = osp.join(root, split, dstype)
+
+ if load_occlusion:
+ occlusion_root = osp.join(root, split, 'occlusions')
+
+ if split == 'test':
+ self.is_test = True
+
+ for scene in os.listdir(image_root):
+ image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+ for i in range(len(image_list) - 1):
+ self.image_list += [[image_list[i], image_list[i + 1]]]
+ self.extra_info += [(scene, i)] # scene and frame_id
+
+ if split != 'test':
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+
+ if load_occlusion:
+ self.occ_list += sorted(glob(osp.join(occlusion_root, scene, '*.png')))
+
+
+class FlyingChairs(FlowDataset):
+ def __init__(self, aug_params=None, split='train',
+ root='datasets/FlyingChairs_release/data',
+ ):
+ super(FlyingChairs, self).__init__(aug_params)
+
+ images = sorted(glob(osp.join(root, '*.ppm')))
+ flows = sorted(glob(osp.join(root, '*.flo')))
+ assert (len(images) // 2 == len(flows))
+
+ split_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chairs_split.txt')
+ split_list = np.loadtxt(split_file, dtype=np.int32)
+ for i in range(len(flows)):
+ xid = split_list[i]
+ if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2):
+ self.flow_list += [flows[i]]
+ self.image_list += [[images[2 * i], images[2 * i + 1]]]
+
+
+class FlyingThings3D(FlowDataset):
+ def __init__(self, aug_params=None,
+ root='datasets/FlyingThings3D',
+ dstype='frames_cleanpass',
+ test_set=False,
+ validate_subset=True,
+ ):
+ super(FlyingThings3D, self).__init__(aug_params)
+
+ img_dir = root
+ flow_dir = root
+
+ for cam in ['left']:
+ for direction in ['into_future', 'into_past']:
+ if test_set:
+ image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TEST/*/*')))
+ else:
+ image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TRAIN/*/*')))
+ image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
+
+ if test_set:
+ flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TEST/*/*')))
+ else:
+ flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TRAIN/*/*')))
+ flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
+
+ for idir, fdir in zip(image_dirs, flow_dirs):
+ images = sorted(glob(osp.join(idir, '*.png')))
+ flows = sorted(glob(osp.join(fdir, '*.pfm')))
+ for i in range(len(flows) - 1):
+ if direction == 'into_future':
+ self.image_list += [[images[i], images[i + 1]]]
+ self.flow_list += [flows[i]]
+ elif direction == 'into_past':
+ self.image_list += [[images[i + 1], images[i]]]
+ self.flow_list += [flows[i + 1]]
+
+ # validate on 1024 subset of test set for fast speed
+ if test_set and validate_subset:
+ num_val_samples = 1024
+ all_test_samples = len(self.image_list) # 7866
+
+ stride = all_test_samples // num_val_samples
+ remove = all_test_samples % num_val_samples
+
+ # uniformly sample a subset
+ self.image_list = self.image_list[:-remove][::stride]
+ self.flow_list = self.flow_list[:-remove][::stride]
+
+
+class KITTI(FlowDataset):
+ def __init__(self, aug_params=None, split='training',
+ root='datasets/KITTI',
+ ):
+ super(KITTI, self).__init__(aug_params, sparse=True,
+ )
+ if split == 'testing':
+ self.is_test = True
+
+ root = osp.join(root, split)
+ images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+ images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+ for img1, img2 in zip(images1, images2):
+ frame_id = img1.split('/')[-1]
+ self.extra_info += [[frame_id]]
+ self.image_list += [[img1, img2]]
+
+ if split == 'training':
+ self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+ def __init__(self, aug_params=None, root='datasets/HD1K'):
+ super(HD1K, self).__init__(aug_params, sparse=True)
+
+ seq_ix = 0
+ while 1:
+ flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+ images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+ if len(flows) == 0:
+ break
+
+ for i in range(len(flows) - 1):
+ self.flow_list += [flows[i]]
+ self.image_list += [[images[i], images[i + 1]]]
+
+ seq_ix += 1
+
+
+def build_train_dataset(args):
+ """ Create the data loader for the corresponding training set """
+ if args.stage == 'chairs':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+
+ train_dataset = FlyingChairs(aug_params, split='training')
+
+ elif args.stage == 'things':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+
+ clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+ final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+ train_dataset = clean_dataset + final_dataset
+
+ elif args.stage == 'sintel':
+ # 1041 pairs for clean and final each
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+
+ things = FlyingThings3D(aug_params, dstype='frames_cleanpass') # 40302
+
+ sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+ sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}
+
+ kitti = KITTI(aug_params=aug_params) # 200
+
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}
+
+ hd1k = HD1K(aug_params=aug_params) # 1047
+
+ train_dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + things
+
+ elif args.stage == 'kitti':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+
+ train_dataset = KITTI(aug_params, split='training',
+ )
+ else:
+ raise ValueError(f'stage {args.stage} is not supported')
+
+ return train_dataset
diff --git a/basicsr/archs/gmflow/data/transforms.py b/basicsr/archs/gmflow/data/transforms.py
new file mode 100755
index 0000000000000000000000000000000000000000..5b1188f3833c97c50429dd5c9644fb5dab3166d7
--- /dev/null
+++ b/basicsr/archs/gmflow/data/transforms.py
@@ -0,0 +1,284 @@
+import numpy as np
+import cv2
+from PIL import Image
+from torchvision.transforms import ColorJitter
+
+
+class FlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True,
+ no_eraser_aug=True,
+ ):
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14)
+
+ self.asymmetric_color_aug_prob = 0.2
+
+ if no_eraser_aug:
+ # we disable eraser aug since no obvious improvement is observed in our experiments
+ self.eraser_aug_prob = -1
+ else:
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ """ Photometric augmentation """
+
+ # asymmetric
+ if np.random.rand() < self.asymmetric_color_aug_prob:
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+ # symmetric
+ else:
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+
+ return img1, img2
+
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
+ """ Occlusion augmentation """
+
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(bounds[0], bounds[1])
+ dy = np.random.randint(bounds[0], bounds[1])
+ img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
+
+ return img1, img2
+
+ def spatial_transform(self, img1, img2, flow, occlusion=None):
+ # randomly sample scale
+ ht, wd = img1.shape[:2]
+
+ min_scale = np.maximum(
+ (self.crop_size[0] + 8) / float(ht),
+ (self.crop_size[1] + 8) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = scale
+ scale_y = scale
+ if np.random.rand() < self.stretch_prob:
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+
+ scale_x = np.clip(scale_x, min_scale, None)
+ scale_y = np.clip(scale_y, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = flow * [scale_x, scale_y]
+
+ if occlusion is not None:
+ occlusion = cv2.resize(occlusion, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+
+ if self.do_flip:
+ if np.random.rand() < self.h_flip_prob: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if occlusion is not None:
+ occlusion = occlusion[:, ::-1]
+
+ if np.random.rand() < self.v_flip_prob: # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ if occlusion is not None:
+ occlusion = occlusion[::-1, :]
+
+ # In case no cropping
+ if img1.shape[0] - self.crop_size[0] > 0:
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+ else:
+ y0 = 0
+ if img1.shape[1] - self.crop_size[1] > 0:
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+ else:
+ x0 = 0
+
+ img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+
+ if occlusion is not None:
+ occlusion = occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ return img1, img2, flow, occlusion
+
+ return img1, img2, flow
+
+ def __call__(self, img1, img2, flow, occlusion=None):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+
+ if occlusion is not None:
+ img1, img2, flow, occlusion = self.spatial_transform(
+ img1, img2, flow, occlusion)
+ else:
+ img1, img2, flow = self.spatial_transform(img1, img2, flow)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+
+ if occlusion is not None:
+ occlusion = np.ascontiguousarray(occlusion)
+ return img1, img2, flow, occlusion
+
+ return img1, img2, flow
+
+
+class SparseFlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False,
+ no_eraser_aug=True,
+ ):
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14)
+ self.asymmetric_color_aug_prob = 0.2
+
+ if no_eraser_aug:
+ # we disable eraser aug since no obvious improvement is observed in our experiments
+ self.eraser_aug_prob = -1
+ else:
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+ return img1, img2
+
+ def eraser_transform(self, img1, img2):
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(50, 100)
+ dy = np.random.randint(50, 100)
+ img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
+
+ return img1, img2
+
+ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
+ ht, wd = flow.shape[:2]
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
+ coords = np.stack(coords, axis=-1)
+
+ coords = coords.reshape(-1, 2).astype(np.float32)
+ flow = flow.reshape(-1, 2).astype(np.float32)
+ valid = valid.reshape(-1).astype(np.float32)
+
+ coords0 = coords[valid >= 1]
+ flow0 = flow[valid >= 1]
+
+ ht1 = int(round(ht * fy))
+ wd1 = int(round(wd * fx))
+
+ coords1 = coords0 * [fx, fy]
+ flow1 = flow0 * [fx, fy]
+
+ xx = np.round(coords1[:, 0]).astype(np.int32)
+ yy = np.round(coords1[:, 1]).astype(np.int32)
+
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+ xx = xx[v]
+ yy = yy[v]
+ flow1 = flow1[v]
+
+ flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
+ valid_img = np.zeros([ht1, wd1], dtype=np.int32)
+
+ flow_img[yy, xx] = flow1
+ valid_img[yy, xx] = 1
+
+ return flow_img, valid_img
+
+ def spatial_transform(self, img1, img2, flow, valid):
+ # randomly sample scale
+
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 1) / float(ht),
+ (self.crop_size[1] + 1) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = np.clip(scale, min_scale, None)
+ scale_y = np.clip(scale, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+
+ flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+
+ if self.do_flip:
+ if np.random.rand() < 0.5: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+ valid = valid[:, ::-1]
+
+ margin_y = 20
+ margin_x = 50
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
+ x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
+
+ y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
+ x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
+ return img1, img2, flow, valid
+
+ def __call__(self, img1, img2, flow, valid):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+
+ img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+ valid = np.ascontiguousarray(valid)
+
+ return img1, img2, flow, valid
diff --git a/basicsr/archs/gmflow/environment.yml b/basicsr/archs/gmflow/environment.yml
new file mode 100755
index 0000000000000000000000000000000000000000..f7e6fd86e66d7b5fad3a38aeb8c6ae02528ca439
--- /dev/null
+++ b/basicsr/archs/gmflow/environment.yml
@@ -0,0 +1,162 @@
+name: gmflow
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=4.5=1_gnu
+ - blas=1.0=mkl
+ - bottleneck=1.3.2=py38heb32a55_1
+ - brotli=1.0.9=he6710b0_2
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2021.10.26=h06a4308_2
+ - certifi=2021.10.8=py38h06a4308_2
+ - cudatoolkit=10.2.89=hfd86e86_1
+ - cycler=0.10.0=py38_0
+ - dbus=1.13.18=hb2f20db_0
+ - expat=2.4.1=h2531618_2
+ - ffmpeg=4.3=hf484d3e_0
+ - fontconfig=2.13.1=h6c09931_0
+ - fonttools=4.25.0=pyhd3eb1b0_0
+ - freetype=2.10.4=h5ab3b9f_0
+ - glib=2.69.0=h5202010_0
+ - gmp=6.2.1=h2531618_2
+ - gnutls=3.6.15=he1e5248_0
+ - gst-plugins-base=1.14.0=h8213a91_2
+ - gstreamer=1.14.0=h28cd5cc_2
+ - icu=58.2=he6710b0_3
+ - imageio=2.9.0=pyhd3eb1b0_0
+ - intel-openmp=2021.3.0=h06a4308_3350
+ - jpeg=9b=h024ee3a_2
+ - kiwisolver=1.3.1=py38h2531618_0
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.35.1=h7274673_9
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=9.3.0=h5101ec6_17
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
+ - libgfortran4=7.5.0=ha8ba4b0_17
+ - libgomp=9.3.0=h5101ec6_17
+ - libiconv=1.15=h63c8f33_5
+ - libidn2=2.3.2=h7f8727e_0
+ - libpng=1.6.37=hbc83047_0
+ - libstdcxx-ng=9.3.0=hd4cf53a_17
+ - libtasn1=4.16.0=h27cfd23_0
+ - libtiff=4.2.0=h85742a9_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.0.3=h1bed415_2
+ - libuv=1.40.0=h7b6447c_0
+ - libwebp-base=1.2.0=h27cfd23_0
+ - libxcb=1.14=h7b6447c_0
+ - libxml2=2.9.12=h03d6c58_0
+ - lz4-c=1.9.3=h2531618_0
+ - matplotlib=3.4.2=py38h06a4308_0
+ - matplotlib-base=3.4.2=py38hab158f2_0
+ - mkl=2021.3.0=h06a4308_520
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.0=py38h42c9631_2
+ - mkl_random=1.2.2=py38h51133e4_0
+ - munkres=1.1.4=py_0
+ - ncurses=6.2=he6710b0_1
+ - nettle=3.7.3=hbbd107a_1
+ - ninja=1.10.2=hff7bd54_1
+ - numexpr=2.7.3=py38h22e1b3c_1
+ - numpy=1.20.3=py38hf144106_0
+ - numpy-base=1.20.3=py38h74d4b33_0
+ - olefile=0.46=py_0
+ - openh264=2.1.0=hd408876_0
+ - openjpeg=2.3.0=h05c96fa_1
+ - openssl=1.1.1m=h7f8727e_0
+ - pandas=1.3.2=py38h8c16a72_0
+ - pcre=8.45=h295c915_0
+ - pillow=8.3.1=py38h2c7a002_0
+ - pip=21.2.2=py38h06a4308_0
+ - pyparsing=2.4.7=pyhd3eb1b0_0
+ - pyqt=5.9.2=py38h05f1152_4
+ - python=3.8.11=h12debd9_0_cpython
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - pytorch=1.9.0=py3.8_cuda10.2_cudnn7.6.5_0
+ - pytz=2021.1=pyhd3eb1b0_0
+ - qt=5.9.7=h5867ecd_1
+ - readline=8.1=h27cfd23_0
+ - scipy=1.6.2=py38had2a1c9_1
+ - seaborn=0.11.2=pyhd3eb1b0_0
+ - setuptools=52.0.0=py38h06a4308_0
+ - sip=4.19.13=py38he6710b0_0
+ - six=1.16.0=pyhd3eb1b0_0
+ - sqlite=3.36.0=hc218d9a_0
+ - tk=8.6.10=hbc83047_0
+ - torchaudio=0.9.0=py38
+ - torchvision=0.10.0=py38_cu102
+ - tornado=6.1=py38h27cfd23_0
+ - typing_extensions=3.10.0.0=pyh06a4308_0
+ - wheel=0.36.2=pyhd3eb1b0_0
+ - xz=5.2.5=h7b6447c_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.4.9=haebb681_0
+ - pip:
+ - absl-py==0.13.0
+ - argon2-cffi==21.1.0
+ - attrs==21.2.0
+ - backcall==0.2.0
+ - bleach==4.1.0
+ - cachetools==4.2.2
+ - cffi==1.14.6
+ - charset-normalizer==2.0.4
+ - debugpy==1.4.3
+ - decorator==5.1.0
+ - defusedxml==0.7.1
+ - einops==0.3.2
+ - entrypoints==0.3
+ - google-auth==1.34.0
+ - google-auth-oauthlib==0.4.5
+ - grpcio==1.39.0
+ - idna==3.2
+ - ipykernel==6.4.1
+ - ipython==7.27.0
+ - ipython-genutils==0.2.0
+ - jedi==0.18.0
+ - jinja2==3.0.1
+ - jsonschema==3.2.0
+ - jupyter-client==7.0.3
+ - jupyter-core==4.8.1
+ - jupyterlab-pygments==0.1.2
+ - markdown==3.3.4
+ - markupsafe==2.0.1
+ - matplotlib-inline==0.1.3
+ - mistune==0.8.4
+ - nbclient==0.5.4
+ - nbconvert==6.1.0
+ - nbformat==5.1.3
+ - nest-asyncio==1.5.1
+ - oauthlib==3.1.1
+ - opencv-python==4.5.3.56
+ - packaging==21.0
+ - pandocfilters==1.5.0
+ - parso==0.8.2
+ - pexpect==4.8.0
+ - pickleshare==0.7.5
+ - prometheus-client==0.11.0
+ - prompt-toolkit==3.0.20
+ - protobuf==3.17.3
+ - ptyprocess==0.7.0
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - pycparser==2.20
+ - pygments==2.10.0
+ - pyrsistent==0.18.0
+ - pyzmq==22.3.0
+ - requests==2.26.0
+ - requests-oauthlib==1.3.0
+ - rsa==4.7.2
+ - send2trash==1.8.0
+ - tensorboard==2.5.0
+ - tensorboard-data-server==0.6.1
+ - tensorboard-plugin-wit==1.8.0
+ - terminado==0.12.1
+ - testpath==0.5.0
+ - traitlets==5.1.0
+ - urllib3==1.26.6
+ - wcwidth==0.2.5
+ - webencodings==0.5.1
+ - werkzeug==2.0.1
diff --git a/basicsr/archs/gmflow/evaluate.py b/basicsr/archs/gmflow/evaluate.py
new file mode 100755
index 0000000000000000000000000000000000000000..be6a3f53d009c843e05a6a5dd75ec9034788b29b
--- /dev/null
+++ b/basicsr/archs/gmflow/evaluate.py
@@ -0,0 +1,689 @@
+from PIL import Image
+import os
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+import data
+from utils import frame_utils
+from utils.flow_viz import save_vis_flow_tofile
+
+from utils.utils import InputPadder, compute_out_of_boundary_mask
+from glob import glob
+from gmflow.geometry import forward_backward_consistency_check
+
+
+@torch.no_grad()
+def create_sintel_submission(model,
+ output_path='sintel_submission',
+ padding_factor=8,
+ save_vis_flow=False,
+ no_save_flo=False,
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ ):
+ """ Create submission for the Sintel leaderboard """
+ model.eval()
+ for dstype in ['clean', 'final']:
+ test_dataset = data.MpiSintel(split='test', aug_params=None, dstype=dstype)
+
+ flow_prev, sequence_prev = None, None
+ for test_id in range(len(test_dataset)):
+ image1, image2, (sequence, frame) = test_dataset[test_id]
+ if sequence != sequence_prev:
+ flow_prev = None
+
+ padder = InputPadder(image1.shape, padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+
+ flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W]
+
+ flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+ output_dir = os.path.join(output_path, dstype, sequence)
+ output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame + 1))
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ if not no_save_flo:
+ frame_utils.writeFlow(output_file, flow)
+ sequence_prev = sequence
+
+ # Save vis flow
+ if save_vis_flow:
+ vis_flow_file = output_file.replace('.flo', '.png')
+ save_vis_flow_tofile(flow, vis_flow_file)
+
+
+@torch.no_grad()
+def create_kitti_submission(model,
+ output_path='kitti_submission',
+ padding_factor=8,
+ save_vis_flow=False,
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ ):
+ """ Create submission for the Sintel leaderboard """
+ model.eval()
+ test_dataset = data.KITTI(split='testing', aug_params=None)
+
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ for test_id in range(len(test_dataset)):
+ image1, image2, (frame_id,) = test_dataset[test_id]
+ padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+
+ flow_pr = results_dict['flow_preds'][-1]
+
+ flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
+
+ output_filename = os.path.join(output_path, frame_id)
+
+ if save_vis_flow:
+ vis_flow_file = output_filename
+ save_vis_flow_tofile(flow, vis_flow_file)
+ else:
+ frame_utils.writeFlowKITTI(output_filename, flow)
+
+
+@torch.no_grad()
+def validate_chairs(model,
+ with_speed_metric=False,
+ attn_splits_list=False,
+ corr_radius_list=False,
+ prop_radius_list=False,
+ ):
+ """ Perform evaluation on the FlyingChairs (test) split """
+ model.eval()
+ epe_list = []
+ results = {}
+
+ if with_speed_metric:
+ s0_10_list = []
+ s10_40_list = []
+ s40plus_list = []
+
+ val_dataset = data.FlyingChairs(split='validation')
+
+ print('Number of validation image pairs: %d' % len(val_dataset))
+
+ for val_id in range(len(val_dataset)):
+ image1, image2, flow_gt, _ = val_dataset[val_id]
+
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+
+ flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W]
+
+ assert flow_pr.size()[-2:] == flow_gt.size()[-2:]
+
+ epe = torch.sum((flow_pr[0].cpu() - flow_gt) ** 2, dim=0).sqrt()
+ epe_list.append(epe.view(-1).numpy())
+
+ if with_speed_metric:
+ flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt()
+ valid_mask = (flow_gt_speed < 10)
+ if valid_mask.max() > 0:
+ s0_10_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40)
+ if valid_mask.max() > 0:
+ s10_40_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed > 40)
+ if valid_mask.max() > 0:
+ s40plus_list.append(epe[valid_mask].cpu().numpy())
+
+ epe_all = np.concatenate(epe_list)
+ epe = np.mean(epe_all)
+ px1 = np.mean(epe_all > 1)
+ px3 = np.mean(epe_all > 3)
+ px5 = np.mean(epe_all > 5)
+ print("Validation Chairs EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (epe, px1, px3, px5))
+ results['chairs_epe'] = epe
+ results['chairs_1px'] = px1
+ results['chairs_3px'] = px3
+ results['chairs_5px'] = px5
+
+ if with_speed_metric:
+ s0_10 = np.mean(np.concatenate(s0_10_list))
+ s10_40 = np.mean(np.concatenate(s10_40_list))
+ s40plus = np.mean(np.concatenate(s40plus_list))
+
+ print("Validation Chairs s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % (
+ s0_10,
+ s10_40,
+ s40plus))
+
+ results['chairs_s0_10'] = s0_10
+ results['chairs_s10_40'] = s10_40
+ results['chairs_s40+'] = s40plus
+
+ return results
+
+
+@torch.no_grad()
+def validate_things(model,
+ padding_factor=8,
+ with_speed_metric=False,
+ max_val_flow=400,
+ val_things_clean_only=True,
+ attn_splits_list=False,
+ corr_radius_list=False,
+ prop_radius_list=False,
+ ):
+ """ Peform validation using the Things (test) split """
+ model.eval()
+ results = {}
+
+ for dstype in ['frames_cleanpass', 'frames_finalpass']:
+ if val_things_clean_only:
+ if dstype == 'frames_finalpass':
+ continue
+
+ val_dataset = data.FlyingThings3D(dstype=dstype, test_set=True, validate_subset=True,
+ )
+ print('Number of validation image pairs: %d' % len(val_dataset))
+ epe_list = []
+
+ if with_speed_metric:
+ s0_10_list = []
+ s10_40_list = []
+ s40plus_list = []
+
+ for val_id in range(len(val_dataset)):
+ image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1, image2)
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+ flow_pr = results_dict['flow_preds'][-1]
+
+ flow = padder.unpad(flow_pr[0]).cpu()
+
+ # Evaluation on flow <= max_val_flow
+ flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt()
+ valid_gt = valid_gt * (flow_gt_speed < max_val_flow)
+ valid_gt = valid_gt.contiguous()
+
+ epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
+ val = valid_gt >= 0.5
+ epe_list.append(epe[val].cpu().numpy())
+
+ if with_speed_metric:
+ valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s0_10_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s10_40_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s40plus_list.append(epe[valid_mask].cpu().numpy())
+
+ epe_list = np.mean(np.concatenate(epe_list))
+
+ epe = np.mean(epe_list)
+
+ if dstype == 'frames_cleanpass':
+ dstype = 'things_clean'
+ if dstype == 'frames_finalpass':
+ dstype = 'things_final'
+
+ print("Validation Things test set (%s) EPE: %.3f" % (dstype, epe))
+ results[dstype + '_epe'] = epe
+
+ if with_speed_metric:
+ s0_10 = np.mean(np.concatenate(s0_10_list))
+ s10_40 = np.mean(np.concatenate(s10_40_list))
+ s40plus = np.mean(np.concatenate(s40plus_list))
+
+ print("Validation Things test (%s) s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % (
+ dstype, s0_10,
+ s10_40,
+ s40plus))
+
+ results[dstype + '_s0_10'] = s0_10
+ results[dstype + '_s10_40'] = s10_40
+ results[dstype + '_s40+'] = s40plus
+
+ return results
+
+
+@torch.no_grad()
+def validate_sintel(model,
+ count_time=False,
+ padding_factor=8,
+ with_speed_metric=False,
+ evaluate_matched_unmatched=False,
+ attn_splits_list=False,
+ corr_radius_list=False,
+ prop_radius_list=False,
+ ):
+ """ Peform validation using the Sintel (train) split """
+ model.eval()
+ results = {}
+
+ if count_time:
+ total_time = 0
+ num_runs = 100
+
+ for dstype in ['clean', 'final']:
+ val_dataset = data.MpiSintel(split='training', dstype=dstype,
+ load_occlusion=evaluate_matched_unmatched,
+ )
+
+ print('Number of validation image pairs: %d' % len(val_dataset))
+ epe_list = []
+
+ if evaluate_matched_unmatched:
+ matched_epe_list = []
+ unmatched_epe_list = []
+
+ if with_speed_metric:
+ s0_10_list = []
+ s10_40_list = []
+ s40plus_list = []
+
+ for val_id in range(len(val_dataset)):
+ if evaluate_matched_unmatched:
+ image1, image2, flow_gt, valid, noc_valid = val_dataset[val_id]
+
+ # compuate in-image-plane valid mask
+ in_image_valid = compute_out_of_boundary_mask(flow_gt.unsqueeze(0)).squeeze(0) # [H, W]
+
+ else:
+ image1, image2, flow_gt, _ = val_dataset[val_id]
+
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1, image2)
+
+ if count_time and val_id >= 5: # 5 warmup
+ torch.cuda.synchronize()
+ time_start = time.perf_counter()
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+
+ # useful when using parallel branches
+ flow_pr = results_dict['flow_preds'][-1]
+
+ if count_time and val_id >= 5:
+ torch.cuda.synchronize()
+ total_time += time.perf_counter() - time_start
+
+ if val_id >= num_runs + 4:
+ break
+
+ flow = padder.unpad(flow_pr[0]).cpu()
+
+ epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
+ epe_list.append(epe.view(-1).numpy())
+
+ if evaluate_matched_unmatched:
+ matched_valid_mask = (noc_valid > 0.5) & (in_image_valid > 0.5)
+
+ if matched_valid_mask.max() > 0:
+ matched_epe_list.append(epe[matched_valid_mask].cpu().numpy())
+ unmatched_epe_list.append(epe[~matched_valid_mask].cpu().numpy())
+
+ if with_speed_metric:
+ flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt()
+ valid_mask = (flow_gt_speed < 10)
+ if valid_mask.max() > 0:
+ s0_10_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40)
+ if valid_mask.max() > 0:
+ s10_40_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed > 40)
+ if valid_mask.max() > 0:
+ s40plus_list.append(epe[valid_mask].cpu().numpy())
+
+ epe_all = np.concatenate(epe_list)
+ epe = np.mean(epe_all)
+ px1 = np.mean(epe_all > 1)
+ px3 = np.mean(epe_all > 3)
+ px5 = np.mean(epe_all > 5)
+
+ dstype_ori = dstype
+
+ print("Validation Sintel (%s) EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (dstype_ori, epe, px1, px3, px5))
+
+ dstype = 'sintel_' + dstype
+
+ results[dstype + '_epe'] = np.mean(epe_list)
+ results[dstype + '_1px'] = px1
+ results[dstype + '_3px'] = px3
+ results[dstype + '_5px'] = px5
+
+ if with_speed_metric:
+ s0_10 = np.mean(np.concatenate(s0_10_list))
+ s10_40 = np.mean(np.concatenate(s10_40_list))
+ s40plus = np.mean(np.concatenate(s40plus_list))
+
+ print("Validation Sintel (%s) s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % (
+ dstype_ori, s0_10,
+ s10_40,
+ s40plus))
+
+ results[dstype + '_s0_10'] = s0_10
+ results[dstype + '_s10_40'] = s10_40
+ results[dstype + '_s40+'] = s40plus
+
+ if count_time:
+ print('Time: %.6fs' % (total_time / num_runs))
+ break # only the clean pass when counting time
+
+ if evaluate_matched_unmatched:
+ matched_epe = np.mean(np.concatenate(matched_epe_list))
+ unmatched_epe = np.mean(np.concatenate(unmatched_epe_list))
+
+ print('Validatation Sintel (%s) matched epe: %.3f, unmatched epe: %.3f' % (
+ dstype_ori, matched_epe, unmatched_epe))
+
+ results[dstype + '_matched'] = matched_epe
+ results[dstype + '_unmatched'] = unmatched_epe
+
+ return results
+
+
+@torch.no_grad()
+def validate_kitti(model,
+ padding_factor=8,
+ with_speed_metric=False,
+ average_over_pixels=True,
+ attn_splits_list=False,
+ corr_radius_list=False,
+ prop_radius_list=False,
+ ):
+ """ Peform validation using the KITTI-2015 (train) split """
+ model.eval()
+
+ val_dataset = data.KITTI(split='training')
+ print('Number of validation image pairs: %d' % len(val_dataset))
+
+ out_list, epe_list = [], []
+ results = {}
+
+ if with_speed_metric:
+ if average_over_pixels:
+ s0_10_list = []
+ s10_40_list = []
+ s40plus_list = []
+ else:
+ s0_10_epe_sum = 0
+ s0_10_valid_samples = 0
+ s10_40_epe_sum = 0
+ s10_40_valid_samples = 0
+ s40plus_epe_sum = 0
+ s40plus_valid_samples = 0
+
+ for val_id in range(len(val_dataset)):
+ image1, image2, flow_gt, valid_gt = val_dataset[val_id]
+ image1 = image1[None].cuda()
+ image2 = image2[None].cuda()
+
+ padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1, image2)
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ )
+
+ # useful when using parallel branches
+ flow_pr = results_dict['flow_preds'][-1]
+
+ flow = padder.unpad(flow_pr[0]).cpu()
+
+ epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt()
+ mag = torch.sum(flow_gt ** 2, dim=0).sqrt()
+
+ if with_speed_metric:
+ # flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt()
+ flow_gt_speed = mag
+
+ if average_over_pixels:
+ valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5) # note KITTI GT is sparse
+ if valid_mask.max() > 0:
+ s0_10_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s10_40_list.append(epe[valid_mask].cpu().numpy())
+
+ valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s40plus_list.append(epe[valid_mask].cpu().numpy())
+
+ else:
+ valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5) # note KITTI GT is sparse
+ if valid_mask.max() > 0:
+ s0_10_epe_sum += (epe * valid_mask).sum() / valid_mask.sum()
+ s0_10_valid_samples += 1
+
+ valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s10_40_epe_sum += (epe * valid_mask).sum() / valid_mask.sum()
+ s10_40_valid_samples += 1
+
+ valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5)
+ if valid_mask.max() > 0:
+ s40plus_epe_sum += (epe * valid_mask).sum() / valid_mask.sum()
+ s40plus_valid_samples += 1
+
+ epe = epe.view(-1)
+ mag = mag.view(-1)
+ val = valid_gt.view(-1) >= 0.5
+
+ out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
+
+ if average_over_pixels:
+ epe_list.append(epe[val].cpu().numpy())
+ else:
+ epe_list.append(epe[val].mean().item())
+
+ out_list.append(out[val].cpu().numpy())
+
+ if average_over_pixels:
+ epe_list = np.concatenate(epe_list)
+ else:
+ epe_list = np.array(epe_list)
+ out_list = np.concatenate(out_list)
+
+ epe = np.mean(epe_list)
+ f1 = 100 * np.mean(out_list)
+
+ print("Validation KITTI EPE: %.3f, F1-all: %.3f" % (epe, f1))
+ results['kitti_epe'] = epe
+ results['kitti_f1'] = f1
+
+ if with_speed_metric:
+ if average_over_pixels:
+ s0_10 = np.mean(np.concatenate(s0_10_list))
+ s10_40 = np.mean(np.concatenate(s10_40_list))
+ s40plus = np.mean(np.concatenate(s40plus_list))
+ else:
+ s0_10 = s0_10_epe_sum / s0_10_valid_samples
+ s10_40 = s10_40_epe_sum / s10_40_valid_samples
+ s40plus = s40plus_epe_sum / s40plus_valid_samples
+
+ print("Validation KITTI s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % (
+ s0_10,
+ s10_40,
+ s40plus))
+
+ results['kitti_s0_10'] = s0_10
+ results['kitti_s10_40'] = s10_40
+ results['kitti_s40+'] = s40plus
+
+ return results
+
+
+@torch.no_grad()
+def inference_on_dir(model,
+ inference_dir,
+ output_path='output',
+ padding_factor=8,
+ inference_size=None,
+ paired_data=False, # dir of paired testdata instead of a sequence
+ save_flo_flow=False, # save as .flo for quantative evaluation
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ pred_bidir_flow=False,
+ fwd_bwd_consistency_check=False,
+ ):
+ """ Inference on a directory """
+ model.eval()
+
+ if fwd_bwd_consistency_check:
+ assert pred_bidir_flow
+
+ if not os.path.exists(output_path):
+ os.makedirs(output_path)
+
+ filenames = sorted(glob(inference_dir + '/*'))
+ print('%d images found' % len(filenames))
+
+ stride = 2 if paired_data else 1
+
+ if paired_data:
+ assert len(filenames) % 2 == 0
+
+ for test_id in range(0, len(filenames) - 1, stride):
+
+ image1 = frame_utils.read_gen(filenames[test_id])
+ image2 = frame_utils.read_gen(filenames[test_id + 1])
+
+ image1 = np.array(image1).astype(np.uint8)
+ image2 = np.array(image2).astype(np.uint8)
+
+ if len(image1.shape) == 2: # gray image, for example, HD1K
+ image1 = np.tile(image1[..., None], (1, 1, 3))
+ image2 = np.tile(image2[..., None], (1, 1, 3))
+ else:
+ image1 = image1[..., :3]
+ image2 = image2[..., :3]
+
+ image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
+ image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
+
+ if inference_size is None:
+ padder = InputPadder(image1.shape, padding_factor=padding_factor)
+ image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
+ else:
+ image1, image2 = image1[None].cuda(), image2[None].cuda()
+
+ # resize before inference
+ if inference_size is not None:
+ assert isinstance(inference_size, list) or isinstance(inference_size, tuple)
+ ori_size = image1.shape[-2:]
+ image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
+ align_corners=True)
+ image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
+ align_corners=True)
+
+ results_dict = model(image1, image2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ pred_bidir_flow=pred_bidir_flow,
+ )
+
+ flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W]
+
+ # resize back
+ if inference_size is not None:
+ flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
+ align_corners=True)
+ flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
+ flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
+
+ if inference_size is None:
+ flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+ else:
+ flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+
+ output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow.png')
+
+ # save vis flow
+ save_vis_flow_tofile(flow, output_file)
+
+ # also predict backward flow
+ if pred_bidir_flow:
+ assert flow_pr.size(0) == 2 # [2, H, W, 2]
+
+ if inference_size is None:
+ flow_bwd = padder.unpad(flow_pr[1]).permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+ else:
+ flow_bwd = flow_pr[1].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
+
+ output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow_bwd.png')
+
+ # save vis flow
+ save_vis_flow_tofile(flow_bwd, output_file)
+
+ # forward-backward consistency check
+ # occlusion is 1
+ if fwd_bwd_consistency_check:
+ if inference_size is None:
+ fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W]
+ bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W]
+ else:
+ fwd_flow = flow_pr[0].unsqueeze(0)
+ bwd_flow = flow_pr[1].unsqueeze(0)
+
+ fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow) # [1, H, W] float
+
+ fwd_occ_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_occ.png')
+ bwd_occ_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_occ_bwd.png')
+
+ Image.fromarray((fwd_occ[0].cpu().numpy() * 255.).astype(np.uint8)).save(fwd_occ_file)
+ Image.fromarray((bwd_occ[0].cpu().numpy() * 255.).astype(np.uint8)).save(bwd_occ_file)
+
+ if save_flo_flow:
+ output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_pred.flo')
+ frame_utils.writeFlow(output_file, flow)
diff --git a/basicsr/archs/gmflow/gmflow/__init__.py b/basicsr/archs/gmflow/gmflow/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/basicsr/archs/gmflow/gmflow/backbone.py b/basicsr/archs/gmflow/gmflow/backbone.py
new file mode 100755
index 0000000000000000000000000000000000000000..d5c92b7d8698a41d11b29f084b3ab4953dd2a7bd
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/backbone.py
@@ -0,0 +1,117 @@
+import torch.nn as nn
+
+from .trident_conv import MultiScaleTridentConv
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
+ ):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
+ dilation=dilation, padding=dilation, stride=stride, bias=False)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+ dilation=dilation, padding=dilation, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.norm1 = norm_layer(planes)
+ self.norm2 = norm_layer(planes)
+ if not stride == 1 or in_planes != planes:
+ self.norm3 = norm_layer(planes)
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class CNNEncoder(nn.Module):
+ def __init__(self, output_dim=128,
+ norm_layer=nn.InstanceNorm2d,
+ num_output_scales=1,
+ **kwargs,
+ ):
+ super(CNNEncoder, self).__init__()
+ self.num_branch = num_output_scales
+
+ feature_dims = [64, 96, 128]
+
+ self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
+ self.norm1 = norm_layer(feature_dims[0])
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = feature_dims[0]
+ self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
+ self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
+
+ # highest resolution 1/4 or 1/8
+ stride = 2 if num_output_scales == 1 else 1
+ self.layer3 = self._make_layer(feature_dims[2], stride=stride,
+ norm_layer=norm_layer,
+ ) # 1/4 or 1/8
+
+ self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
+
+ if self.num_branch > 1:
+ if self.num_branch == 4:
+ strides = (1, 2, 4, 8)
+ elif self.num_branch == 3:
+ strides = (1, 2, 4)
+ elif self.num_branch == 2:
+ strides = (1, 2)
+ else:
+ raise ValueError
+
+ self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
+ kernel_size=3,
+ strides=strides,
+ paddings=1,
+ num_branch=self.num_branch,
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
+ layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
+ layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
+
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x) # 1/2
+ x = self.layer2(x) # 1/4
+ x = self.layer3(x) # 1/8 or 1/4
+
+ x = self.conv2(x)
+
+ if self.num_branch > 1:
+ out = self.trident_conv([x] * self.num_branch) # high to low res
+ else:
+ out = [x]
+
+ return out
diff --git a/basicsr/archs/gmflow/gmflow/geometry.py b/basicsr/archs/gmflow/gmflow/geometry.py
new file mode 100755
index 0000000000000000000000000000000000000000..207e98fded56c0e7e63d63626ddace65b910bf9c
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/geometry.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn.functional as F
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None):
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
+
+ stacks = [x, y]
+
+ if homogeneous:
+ ones = torch.ones_like(x) # [H, W]
+ stacks.append(ones)
+
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+
+ if device is not None:
+ grid = grid.to(device)
+
+ return grid
+
+
+def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+ assert device is not None
+
+ x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
+ torch.linspace(h_min, h_max, len_h, device=device)],
+ )
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
+
+ return grid
+
+
+def normalize_coords(coords, h, w):
+ # coords: [B, H, W, 2]
+ c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
+ return (coords - c) / c # [-1, 1]
+
+
+def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
+ # img: [B, C, H, W]
+ # sample_coords: [B, 2, H, W] in image scale
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
+
+ b, _, h, w = sample_coords.shape
+
+ # Normalize to [-1, 1]
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
+
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
+
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
+
+ if return_mask:
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
+
+ return img, mask
+
+ return img
+
+
+def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
+ b, c, h, w = feature.size()
+ assert flow.size(1) == 2
+
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
+
+ return bilinear_sample(feature, grid, padding_mode=padding_mode,
+ return_mask=mask)
+
+
+def forward_backward_consistency_check(fwd_flow, bwd_flow,
+ alpha=0.01,
+ beta=0.5
+ ):
+ # fwd_flow, bwd_flow: [B, 2, H, W]
+ # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
+
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
+
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
+
+ threshold = alpha * flow_mag + beta
+
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
+ bwd_occ = (diff_bwd > threshold).float()
+
+ return fwd_occ, bwd_occ
diff --git a/basicsr/archs/gmflow/gmflow/gmflow.py b/basicsr/archs/gmflow/gmflow/gmflow.py
new file mode 100755
index 0000000000000000000000000000000000000000..7191df080824317e43134680975dd991360d2f79
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/gmflow.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .backbone import CNNEncoder
+from .transformer import FeatureTransformer, FeatureFlowAttention
+from .matching import global_correlation_softmax, local_correlation_softmax
+from .geometry import flow_warp
+from .utils import normalize_img, feature_add_position
+
+
+class GMFlow(nn.Module):
+ def __init__(self,
+ num_scales=1,
+ upsample_factor=8,
+ feature_channels=128,
+ attention_type='swin',
+ num_transformer_layers=6,
+ ffn_dim_expansion=4,
+ num_head=1,
+ **kwargs,
+ ):
+ super(GMFlow, self).__init__()
+
+ self.num_scales = num_scales
+ self.feature_channels = feature_channels
+ self.upsample_factor = upsample_factor
+ self.attention_type = attention_type
+ self.num_transformer_layers = num_transformer_layers
+
+ # CNN backbone
+ self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
+
+ # Transformer
+ self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
+ d_model=feature_channels,
+ nhead=num_head,
+ attention_type=attention_type,
+ ffn_dim_expansion=ffn_dim_expansion,
+ )
+
+ # flow propagation with self-attn
+ self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)
+
+ # convex upsampling: concat feature0 and flow as input
+ self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
+
+ def extract_feature(self, img0, img1):
+ concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
+ features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
+
+ # reverse: resolution from low to high
+ features = features[::-1]
+
+ feature0, feature1 = [], []
+
+ for i in range(len(features)):
+ feature = features[i]
+ chunks = torch.chunk(feature, 2, 0) # tuple
+ feature0.append(chunks[0])
+ feature1.append(chunks[1])
+
+ return feature0, feature1
+
+ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
+ ):
+ if bilinear:
+ up_flow = F.interpolate(flow, scale_factor=upsample_factor,
+ mode='bilinear', align_corners=True) * upsample_factor
+
+ else:
+ # convex upsampling
+ concat = torch.cat((flow, feature), dim=1)
+
+ mask = self.upsampler(concat)
+ b, flow_channel, h, w = flow.shape
+ mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
+ mask = torch.softmax(mask, dim=2)
+
+ up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
+ up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
+
+ up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
+ up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h,
+ self.upsample_factor * w) # [B, 2, K*H, K*W]
+
+ return up_flow
+
+ def forward(self, img0, img1,
+ attn_splits_list=None,
+ corr_radius_list=None,
+ prop_radius_list=None,
+ pred_bidir_flow=False,
+ **kwargs,
+ ):
+
+ results_dict = {}
+ flow_preds = []
+
+ img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
+
+ # resolution low to high
+ feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
+
+ flow = None
+
+ assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
+
+ for scale_idx in range(self.num_scales):
+ feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
+
+ if pred_bidir_flow and scale_idx > 0:
+ # predicting bidirectional flow with refinement
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
+
+ upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
+
+ if scale_idx > 0:
+ flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
+
+ if flow is not None:
+ flow = flow.detach()
+ feature1 = flow_warp(feature1, flow) # [B, C, H, W]
+
+ attn_splits = attn_splits_list[scale_idx]
+ corr_radius = corr_radius_list[scale_idx]
+ prop_radius = prop_radius_list[scale_idx]
+
+ # add position to features
+ feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
+
+ # Transformer
+ feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
+
+ # correlation and softmax
+ if corr_radius == -1: # global matching
+ flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
+ else: # local matching
+ flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
+
+ # flow or residual flow
+ flow = flow + flow_pred if flow is not None else flow_pred
+
+ # upsample to the original resolution for supervison
+ if self.training: # only need to upsample intermediate flow predictions at training time
+ flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor)
+ flow_preds.append(flow_bilinear)
+
+ # flow propagation with self-attn
+ if pred_bidir_flow and scale_idx == 0:
+ feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
+ flow = self.feature_flow_attn(feature0, flow.detach(),
+ local_window_attn=prop_radius > 0,
+ local_window_radius=prop_radius)
+
+ # bilinear upsampling at training time except the last one
+ if self.training and scale_idx < self.num_scales - 1:
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor)
+ flow_preds.append(flow_up)
+
+ if scale_idx == self.num_scales - 1:
+ flow_up = self.upsample_flow(flow, feature0)
+ flow_preds.append(flow_up)
+
+ results_dict.update({'flow_preds': flow_preds})
+
+ return results_dict
diff --git a/basicsr/archs/gmflow/gmflow/matching.py b/basicsr/archs/gmflow/gmflow/matching.py
new file mode 100755
index 0000000000000000000000000000000000000000..e920081552c3040c95b6a7b55779249f76cbad4b
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/matching.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn.functional as F
+
+from .geometry import coords_grid, generate_window_grid, normalize_coords
+
+
+def global_correlation_softmax(feature0, feature1,
+ pred_bidir_flow=False,
+ ):
+ # global correlation
+ b, c, h, w = feature0.shape
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
+
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
+
+ # flow from softmax
+ init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
+ grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
+
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
+
+ if pred_bidir_flow:
+ correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
+ init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
+ grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
+ b = b * 2
+
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
+
+ correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
+ flow = correspondence - init_grid
+
+ return flow, prob
+
+
+def local_correlation_softmax(feature0, feature1, local_radius,
+ padding_mode='zeros',
+ ):
+ b, c, h, w = feature0.size()
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
+
+ local_h = 2 * local_radius + 1
+ local_w = 2 * local_radius + 1
+
+ window_grid = generate_window_grid(-local_radius, local_radius,
+ -local_radius, local_radius,
+ local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
+
+ sample_coords_softmax = sample_coords
+
+ # exclude coords that are out of image space
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
+
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
+
+ # normalize coordinates to [-1, 1]
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
+ padding_mode=padding_mode, align_corners=True
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
+
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
+
+ # mask invalid locations
+ corr[~valid] = -1e9
+
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
+
+ correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
+ b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ flow = correspondence - coords_init
+ match_prob = prob
+
+ return flow, match_prob
diff --git a/basicsr/archs/gmflow/gmflow/position.py b/basicsr/archs/gmflow/gmflow/position.py
new file mode 100755
index 0000000000000000000000000000000000000000..42435d0fef24737d3cae7463ca411a635979cf33
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/position.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
+
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x):
+ # x = tensor_list.tensors # [B, C, H, W]
+ # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
+ b, c, h, w = x.size()
+ mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
+ y_embed = mask.cumsum(1, dtype=torch.float32)
+ x_embed = mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
diff --git a/basicsr/archs/gmflow/gmflow/transformer.py b/basicsr/archs/gmflow/gmflow/transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..dcf657c86959c2b4528c12f698cd6a26874e432f
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/transformer.py
@@ -0,0 +1,409 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import split_feature, merge_splits
+
+
+def single_head_full_attention(q, k, v):
+ # q, k, v: [B, L, C]
+ assert q.dim() == k.dim() == v.dim() == 3
+
+ scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
+ attn = torch.softmax(scores, dim=2) # [B, L, L]
+ out = torch.matmul(attn, v) # [B, L, C]
+
+ return out
+
+
+def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
+ shift_size_h, shift_size_w, device=torch.device('cuda')):
+ # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+ # calculate attention mask for SW-MSA
+ h, w = input_resolution
+ img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
+ h_slices = (slice(0, -window_size_h),
+ slice(-window_size_h, -shift_size_h),
+ slice(-shift_size_h, None))
+ w_slices = (slice(0, -window_size_w),
+ slice(-window_size_w, -shift_size_w),
+ slice(-shift_size_w, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
+
+ mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+
+def single_head_split_window_attention(q, k, v,
+ num_splits=1,
+ with_shift=False,
+ h=None,
+ w=None,
+ attn_mask=None,
+ ):
+ # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+ # q, k, v: [B, L, C]
+ assert q.dim() == k.dim() == v.dim() == 3
+
+ assert h is not None and w is not None
+ assert q.size(1) == h * w
+
+ b, _, c = q.size()
+
+ b_new = b * num_splits * num_splits
+
+ window_size_h = h // num_splits
+ window_size_w = w // num_splits
+
+ q = q.view(b, h, w, c) # [B, H, W, C]
+ k = k.view(b, h, w, c)
+ v = v.view(b, h, w, c)
+
+ scale_factor = c ** 0.5
+
+ if with_shift:
+ assert attn_mask is not None # compute once
+ shift_size_h = window_size_h // 2
+ shift_size_w = window_size_w // 2
+
+ q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+ k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+ v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
+
+ q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
+ k = split_feature(k, num_splits=num_splits, channel_last=True)
+ v = split_feature(v, num_splits=num_splits, channel_last=True)
+
+ scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+ ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
+
+ if with_shift:
+ scores += attn_mask.repeat(b, 1, 1)
+
+ attn = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
+
+ out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
+ num_splits=num_splits, channel_last=True) # [B, H, W, C]
+
+ # shift back
+ if with_shift:
+ out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
+
+ out = out.view(b, -1, c)
+
+ return out
+
+
+class TransformerLayer(nn.Module):
+ def __init__(self,
+ d_model=256,
+ nhead=1,
+ attention_type='swin',
+ no_ffn=False,
+ ffn_dim_expansion=4,
+ with_shift=False,
+ **kwargs,
+ ):
+ super(TransformerLayer, self).__init__()
+
+ self.dim = d_model
+ self.nhead = nhead
+ self.attention_type = attention_type
+ self.no_ffn = no_ffn
+
+ self.with_shift = with_shift
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # no ffn after self-attn, with ffn after cross-attn
+ if not self.no_ffn:
+ in_channels = d_model * 2
+ self.mlp = nn.Sequential(
+ nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
+ nn.GELU(),
+ nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
+ )
+
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, source, target,
+ height=None,
+ width=None,
+ shifted_window_attn_mask=None,
+ attn_num_splits=None,
+ **kwargs,
+ ):
+ # source, target: [B, L, C]
+ query, key, value = source, target, target
+
+ # single-head attention
+ query = self.q_proj(query) # [B, L, C]
+ key = self.k_proj(key) # [B, L, C]
+ value = self.v_proj(value) # [B, L, C]
+
+ if self.attention_type == 'swin' and attn_num_splits > 1:
+ if self.nhead > 1:
+ # we observe that multihead attention slows down the speed and increases the memory consumption
+ # without bringing obvious performance gains and thus the implementation is removed
+ raise NotImplementedError
+ else:
+ message = single_head_split_window_attention(query, key, value,
+ num_splits=attn_num_splits,
+ with_shift=self.with_shift,
+ h=height,
+ w=width,
+ attn_mask=shifted_window_attn_mask,
+ )
+ else:
+ message = single_head_full_attention(query, key, value) # [B, L, C]
+
+ message = self.merge(message) # [B, L, C]
+ message = self.norm1(message)
+
+ if not self.no_ffn:
+ message = self.mlp(torch.cat([source, message], dim=-1))
+ message = self.norm2(message)
+
+ return source + message
+
+
+class TransformerBlock(nn.Module):
+ """self attention + cross attention + FFN"""
+
+ def __init__(self,
+ d_model=256,
+ nhead=1,
+ attention_type='swin',
+ ffn_dim_expansion=4,
+ with_shift=False,
+ **kwargs,
+ ):
+ super(TransformerBlock, self).__init__()
+
+ self.self_attn = TransformerLayer(d_model=d_model,
+ nhead=nhead,
+ attention_type=attention_type,
+ no_ffn=True,
+ ffn_dim_expansion=ffn_dim_expansion,
+ with_shift=with_shift,
+ )
+
+ self.cross_attn_ffn = TransformerLayer(d_model=d_model,
+ nhead=nhead,
+ attention_type=attention_type,
+ ffn_dim_expansion=ffn_dim_expansion,
+ with_shift=with_shift,
+ )
+
+ def forward(self, source, target,
+ height=None,
+ width=None,
+ shifted_window_attn_mask=None,
+ attn_num_splits=None,
+ **kwargs,
+ ):
+ # source, target: [B, L, C]
+
+ # self attention
+ source = self.self_attn(source, source,
+ height=height,
+ width=width,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ attn_num_splits=attn_num_splits,
+ )
+
+ # cross attention and ffn
+ source = self.cross_attn_ffn(source, target,
+ height=height,
+ width=width,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ attn_num_splits=attn_num_splits,
+ )
+
+ return source
+
+
+class FeatureTransformer(nn.Module):
+ def __init__(self,
+ num_layers=6,
+ d_model=128,
+ nhead=1,
+ attention_type='swin',
+ ffn_dim_expansion=4,
+ **kwargs,
+ ):
+ super(FeatureTransformer, self).__init__()
+
+ self.attention_type = attention_type
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ self.layers = nn.ModuleList([
+ TransformerBlock(d_model=d_model,
+ nhead=nhead,
+ attention_type=attention_type,
+ ffn_dim_expansion=ffn_dim_expansion,
+ with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
+ )
+ for i in range(num_layers)])
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feature0, feature1,
+ attn_num_splits=None,
+ **kwargs,
+ ):
+
+ b, c, h, w = feature0.shape
+ assert self.d_model == c
+
+ feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
+ feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
+
+ if self.attention_type == 'swin' and attn_num_splits > 1:
+ # global and refine use different number of splits
+ window_size_h = h // attn_num_splits
+ window_size_w = w // attn_num_splits
+
+ # compute attn mask once
+ shifted_window_attn_mask = generate_shift_window_attn_mask(
+ input_resolution=(h, w),
+ window_size_h=window_size_h,
+ window_size_w=window_size_w,
+ shift_size_h=window_size_h // 2,
+ shift_size_w=window_size_w // 2,
+ device=feature0.device,
+ ) # [K*K, H/K*W/K, H/K*W/K]
+ else:
+ shifted_window_attn_mask = None
+
+ # concat feature0 and feature1 in batch dimension to compute in parallel
+ concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
+ concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
+
+ for layer in self.layers:
+ concat0 = layer(concat0, concat1,
+ height=h,
+ width=w,
+ shifted_window_attn_mask=shifted_window_attn_mask,
+ attn_num_splits=attn_num_splits,
+ )
+
+ # update feature1
+ concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
+
+ feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
+
+ # reshape back
+ feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
+ feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
+
+ return feature0, feature1
+
+
+class FeatureFlowAttention(nn.Module):
+ """
+ flow propagation with self-attention on feature
+ query: feature0, key: feature0, value: flow
+ """
+
+ def __init__(self, in_channels,
+ **kwargs,
+ ):
+ super(FeatureFlowAttention, self).__init__()
+
+ self.q_proj = nn.Linear(in_channels, in_channels)
+ self.k_proj = nn.Linear(in_channels, in_channels)
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feature0, flow,
+ local_window_attn=False,
+ local_window_radius=1,
+ **kwargs,
+ ):
+ # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
+ if local_window_attn:
+ return self.forward_local_window_attn(feature0, flow,
+ local_window_radius=local_window_radius)
+
+ b, c, h, w = feature0.size()
+
+ query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
+
+ # a note: the ``correct'' implementation should be:
+ # ``query = self.q_proj(query), key = self.k_proj(query)''
+ # this problem is observed while cleaning up the code
+ # however, this doesn't affect the performance since the projection is a linear operation,
+ # thus the two projection matrices for key can be merged
+ # so I just leave it as is in order to not re-train all models :)
+ query = self.q_proj(query) # [B, H*W, C]
+ key = self.k_proj(query) # [B, H*W, C]
+
+ value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
+
+ scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
+ prob = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(prob, value) # [B, H*W, 2]
+ out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
+
+ return out
+
+ def forward_local_window_attn(self, feature0, flow,
+ local_window_radius=1,
+ ):
+ assert flow.size(1) == 2
+ assert local_window_radius > 0
+
+ b, c, h, w = feature0.size()
+
+ feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
+ ).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
+
+ kernel_size = 2 * local_window_radius + 1
+
+ feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
+
+ feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
+ padding=local_window_radius) # [B, C*(2R+1)^2), H*W]
+
+ feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
+ 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
+
+ flow_window = F.unfold(flow, kernel_size=kernel_size,
+ padding=local_window_radius) # [B, 2*(2R+1)^2), H*W]
+
+ flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
+ 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2]
+
+ scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
+
+ prob = torch.softmax(scores, dim=-1)
+
+ out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
+
+ return out
diff --git a/basicsr/archs/gmflow/gmflow/trident_conv.py b/basicsr/archs/gmflow/gmflow/trident_conv.py
new file mode 100755
index 0000000000000000000000000000000000000000..445663c2d1065e10899f728ad2628e313f218024
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/trident_conv.py
@@ -0,0 +1,90 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair
+
+
+class MultiScaleTridentConv(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ strides=1,
+ paddings=0,
+ dilations=1,
+ dilation=1,
+ groups=1,
+ num_branch=1,
+ test_branch_idx=-1,
+ bias=False,
+ norm=None,
+ activation=None,
+ ):
+ super(MultiScaleTridentConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.num_branch = num_branch
+ self.stride = _pair(stride)
+ self.groups = groups
+ self.with_bias = bias
+ self.dilation = dilation
+ if isinstance(paddings, int):
+ paddings = [paddings] * self.num_branch
+ if isinstance(dilations, int):
+ dilations = [dilations] * self.num_branch
+ if isinstance(strides, int):
+ strides = [strides] * self.num_branch
+ self.paddings = [_pair(padding) for padding in paddings]
+ self.dilations = [_pair(dilation) for dilation in dilations]
+ self.strides = [_pair(stride) for stride in strides]
+ self.test_branch_idx = test_branch_idx
+ self.norm = norm
+ self.activation = activation
+
+ assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
+ )
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.bias = None
+
+ nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, inputs):
+ num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
+ assert len(inputs) == num_branch
+
+ if self.training or self.test_branch_idx == -1:
+ outputs = [
+ F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
+ for input, stride, padding in zip(inputs, self.strides, self.paddings)
+ ]
+ else:
+ outputs = [
+ F.conv2d(
+ inputs[0],
+ self.weight,
+ self.bias,
+ self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
+ self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
+ self.dilation,
+ self.groups,
+ )
+ ]
+
+ if self.norm is not None:
+ outputs = [self.norm(x) for x in outputs]
+ if self.activation is not None:
+ outputs = [self.activation(x) for x in outputs]
+ return outputs
diff --git a/basicsr/archs/gmflow/gmflow/utils.py b/basicsr/archs/gmflow/gmflow/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..4a0d27eeb43b0d0601b95c16aea651620c1250dc
--- /dev/null
+++ b/basicsr/archs/gmflow/gmflow/utils.py
@@ -0,0 +1,86 @@
+import torch
+from .position import PositionEmbeddingSine
+
+
+def split_feature(feature,
+ num_splits=2,
+ channel_last=False,
+ ):
+ if channel_last: # [B, H, W, C]
+ b, h, w, c = feature.size()
+ assert h % num_splits == 0 and w % num_splits == 0
+
+ b_new = b * num_splits * num_splits
+ h_new = h // num_splits
+ w_new = w // num_splits
+
+ feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
+ ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
+ else: # [B, C, H, W]
+ b, c, h, w = feature.size()
+ assert h % num_splits == 0 and w % num_splits == 0
+
+ b_new = b * num_splits * num_splits
+ h_new = h // num_splits
+ w_new = w // num_splits
+
+ feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
+ ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
+
+ return feature
+
+
+def merge_splits(splits,
+ num_splits=2,
+ channel_last=False,
+ ):
+ if channel_last: # [B*K*K, H/K, W/K, C]
+ b, h, w, c = splits.size()
+ new_b = b // num_splits // num_splits
+
+ splits = splits.view(new_b, num_splits, num_splits, h, w, c)
+ merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
+ new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
+ else: # [B*K*K, C, H/K, W/K]
+ b, c, h, w = splits.size()
+ new_b = b // num_splits // num_splits
+
+ splits = splits.view(new_b, num_splits, num_splits, c, h, w)
+ merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
+ new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
+
+ return merge
+
+
+def normalize_img(img0, img1):
+ # loaded images are in [0, 255]
+ # normalize by ImageNet mean and std
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
+ img0 = (img0 / 255. - mean) / std
+ img1 = (img1 / 255. - mean) / std
+
+ return img0, img1
+
+
+def feature_add_position(feature0, feature1, attn_splits, feature_channels):
+ pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
+
+ if attn_splits > 1: # add position in splited window
+ feature0_splits = split_feature(feature0, num_splits=attn_splits)
+ feature1_splits = split_feature(feature1, num_splits=attn_splits)
+
+ position = pos_enc(feature0_splits)
+
+ feature0_splits = feature0_splits + position
+ feature1_splits = feature1_splits + position
+
+ feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
+ feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
+ else:
+ position = pos_enc(feature0)
+
+ feature0 = feature0 + position
+ feature1 = feature1 + position
+
+ return feature0, feature1
diff --git a/basicsr/archs/gmflow/loss.py b/basicsr/archs/gmflow/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..f9f0b0216b2a277f422daab92d5c5b4f53458ae3
--- /dev/null
+++ b/basicsr/archs/gmflow/loss.py
@@ -0,0 +1,37 @@
+import torch
+
+
+def flow_loss_func(flow_preds, flow_gt, valid,
+ gamma=0.9,
+ max_flow=400,
+ **kwargs,
+ ):
+ n_predictions = len(flow_preds)
+ flow_loss = 0.0
+
+ # exlude invalid pixels and extremely large diplacements
+ mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W]
+ valid = (valid >= 0.5) & (mag < max_flow)
+
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+
+ i_loss = (flow_preds[i] - flow_gt).abs()
+
+ flow_loss += i_weight * (valid[:, None] * i_loss).mean()
+
+ epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
+
+ if valid.max() < 0.5:
+ pass
+
+ epe = epe.view(-1)[valid.view(-1)]
+
+ metrics = {
+ 'epe': epe.mean().item(),
+ '1px': (epe > 1).float().mean().item(),
+ '3px': (epe > 3).float().mean().item(),
+ '5px': (epe > 5).float().mean().item(),
+ }
+
+ return flow_loss, metrics
diff --git a/basicsr/archs/gmflow/main.py b/basicsr/archs/gmflow/main.py
new file mode 100755
index 0000000000000000000000000000000000000000..281b402e2a9032dd992dd8dc126e9bd897d86f3d
--- /dev/null
+++ b/basicsr/archs/gmflow/main.py
@@ -0,0 +1,557 @@
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import argparse
+import numpy as np
+import os
+
+from data import build_train_dataset
+from gmflow.gmflow import GMFlow
+from loss import flow_loss_func
+from evaluate import (validate_chairs, validate_things, validate_sintel, validate_kitti,
+ create_sintel_submission, create_kitti_submission, inference_on_dir)
+
+from utils.logger import Logger
+from utils import misc
+from utils.dist_utils import get_dist_info, init_dist, setup_for_distributed
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser()
+
+ # dataset
+ parser.add_argument('--checkpoint_dir', default='tmp', type=str,
+ help='where to save the training log and models')
+ parser.add_argument('--stage', default='chairs', type=str,
+ help='training stage')
+ parser.add_argument('--image_size', default=[384, 512], type=int, nargs='+',
+ help='image size for training')
+ parser.add_argument('--padding_factor', default=16, type=int,
+ help='the input should be divisible by padding_factor, otherwise do padding')
+
+ parser.add_argument('--max_flow', default=400, type=int,
+ help='exclude very large motions during training')
+ parser.add_argument('--val_dataset', default=['chairs'], type=str, nargs='+',
+ help='validation dataset')
+ parser.add_argument('--with_speed_metric', action='store_true',
+ help='with speed metric when evaluation')
+
+ # training
+ parser.add_argument('--lr', default=4e-4, type=float)
+ parser.add_argument('--batch_size', default=12, type=int)
+ parser.add_argument('--num_workers', default=4, type=int)
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
+ parser.add_argument('--grad_clip', default=1.0, type=float)
+ parser.add_argument('--num_steps', default=100000, type=int)
+ parser.add_argument('--seed', default=326, type=int)
+ parser.add_argument('--summary_freq', default=100, type=int)
+ parser.add_argument('--val_freq', default=10000, type=int)
+ parser.add_argument('--save_ckpt_freq', default=10000, type=int)
+ parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int)
+
+ # resume pretrained model or resume training
+ parser.add_argument('--resume', default=None, type=str,
+ help='resume from pretrain model for finetuing or resume from terminated training')
+ parser.add_argument('--strict_resume', action='store_true')
+ parser.add_argument('--no_resume_optimizer', action='store_true')
+
+ # GMFlow model
+ parser.add_argument('--num_scales', default=1, type=int,
+ help='basic gmflow model uses a single 1/8 feature, the refinement uses 1/4 feature')
+ parser.add_argument('--feature_channels', default=128, type=int)
+ parser.add_argument('--upsample_factor', default=8, type=int)
+ parser.add_argument('--num_transformer_layers', default=6, type=int)
+ parser.add_argument('--num_head', default=1, type=int)
+ parser.add_argument('--attention_type', default='swin', type=str)
+ parser.add_argument('--ffn_dim_expansion', default=4, type=int)
+
+ parser.add_argument('--attn_splits_list', default=[2], type=int, nargs='+',
+ help='number of splits in attention')
+ parser.add_argument('--corr_radius_list', default=[-1], type=int, nargs='+',
+ help='correlation radius for matching, -1 indicates global matching')
+ parser.add_argument('--prop_radius_list', default=[-1], type=int, nargs='+',
+ help='self-attention radius for flow propagation, -1 indicates global attention')
+
+ # loss
+ parser.add_argument('--gamma', default=0.9, type=float,
+ help='loss weight')
+
+ # evaluation
+ parser.add_argument('--eval', action='store_true')
+ parser.add_argument('--save_eval_to_file', action='store_true')
+ parser.add_argument('--evaluate_matched_unmatched', action='store_true')
+
+ # inference on a directory
+ parser.add_argument('--inference_dir', default=None, type=str)
+ parser.add_argument('--inference_size', default=None, type=int, nargs='+',
+ help='can specify the inference size')
+ parser.add_argument('--dir_paired_data', action='store_true',
+ help='Paired data in a dir instead of a sequence')
+ parser.add_argument('--save_flo_flow', action='store_true')
+ parser.add_argument('--pred_bidir_flow', action='store_true',
+ help='predict bidirectional flow')
+ parser.add_argument('--fwd_bwd_consistency_check', action='store_true',
+ help='forward backward consistency check with bidirection flow')
+
+ # predict on sintel and kitti test set for submission
+ parser.add_argument('--submission', action='store_true',
+ help='submission to sintel or kitti test sets')
+ parser.add_argument('--output_path', default='output', type=str,
+ help='where to save the prediction results')
+ parser.add_argument('--save_vis_flow', action='store_true',
+ help='visualize flow prediction as .png image')
+ parser.add_argument('--no_save_flo', action='store_true',
+ help='not save flow as .flo')
+
+ # distributed training
+ parser.add_argument('--local_rank', default=0, type=int)
+ parser.add_argument('--distributed', action='store_true')
+ parser.add_argument('--launcher', default='none', type=str, choices=['none', 'pytorch'])
+ parser.add_argument('--gpu_ids', default=0, type=int, nargs='+')
+
+ parser.add_argument('--count_time', action='store_true',
+ help='measure the inference time on sintel')
+
+ return parser
+
+
+def main(args):
+ if not args.eval and not args.submission and args.inference_dir is None:
+ if args.local_rank == 0:
+ print('pytorch version:', torch.__version__)
+ print(args)
+ misc.save_args(args)
+ misc.check_path(args.checkpoint_dir)
+ misc.save_command(args.checkpoint_dir)
+
+ seed = args.seed
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ torch.backends.cudnn.benchmark = True
+
+ if args.launcher == 'none':
+ args.distributed = False
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ else:
+ args.distributed = True
+
+ # adjust batch size for each gpu
+ assert args.batch_size % torch.cuda.device_count() == 0
+ args.batch_size = args.batch_size // torch.cuda.device_count()
+
+ dist_params = dict(backend='nccl')
+ init_dist(args.launcher, **dist_params)
+ # re-set gpu_ids with distributed training mode
+ _, world_size = get_dist_info()
+ args.gpu_ids = range(world_size)
+ device = torch.device('cuda:{}'.format(args.local_rank))
+
+ setup_for_distributed(args.local_rank == 0)
+
+ # model
+ model = GMFlow(feature_channels=args.feature_channels,
+ num_scales=args.num_scales,
+ upsample_factor=args.upsample_factor,
+ num_head=args.num_head,
+ attention_type=args.attention_type,
+ ffn_dim_expansion=args.ffn_dim_expansion,
+ num_transformer_layers=args.num_transformer_layers,
+ ).to(device)
+
+ if not args.eval and not args.submission and not args.inference_dir:
+ print('Model definition:')
+ print(model)
+
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model.to(device),
+ device_ids=[args.local_rank],
+ output_device=args.local_rank)
+ model_without_ddp = model.module
+ else:
+ if torch.cuda.device_count() > 1:
+ print('Use %d GPUs' % torch.cuda.device_count())
+ model = torch.nn.DataParallel(model)
+
+ model_without_ddp = model.module
+ else:
+ model_without_ddp = model
+
+ num_params = sum(p.numel() for p in model.parameters())
+ print('Number of params:', num_params)
+ if not args.eval and not args.submission and args.inference_dir is None:
+ save_name = '%d_parameters' % num_params
+ open(os.path.join(args.checkpoint_dir, save_name), 'a').close()
+
+ optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr,
+ weight_decay=args.weight_decay)
+
+ start_epoch = 0
+ start_step = 0
+ # resume checkpoints
+ if args.resume:
+ print('Load checkpoint: %s' % args.resume)
+
+ loc = 'cuda:{}'.format(args.local_rank)
+ checkpoint = torch.load(args.resume, map_location=loc)
+
+ weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
+
+ model_without_ddp.load_state_dict(weights, strict=args.strict_resume)
+
+ if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \
+ args.no_resume_optimizer:
+ print('Load optimizer')
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ start_epoch = checkpoint['epoch']
+ start_step = checkpoint['step']
+
+ print('start_epoch: %d, start_step: %d' % (start_epoch, start_step))
+
+ # evaluate
+ if args.eval:
+ val_results = {}
+
+ if 'chairs' in args.val_dataset:
+ results_dict = validate_chairs(model_without_ddp,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+
+ val_results.update(results_dict)
+
+ if 'things' in args.val_dataset:
+ results_dict = validate_things(model_without_ddp,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ val_results.update(results_dict)
+
+ if 'sintel' in args.val_dataset:
+ results_dict = validate_sintel(model_without_ddp,
+ count_time=args.count_time,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ evaluate_matched_unmatched=args.evaluate_matched_unmatched,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ val_results.update(results_dict)
+
+ if 'kitti' in args.val_dataset:
+ results_dict = validate_kitti(model_without_ddp,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ val_results.update(results_dict)
+
+ if args.save_eval_to_file:
+ misc.check_path(args.checkpoint_dir)
+ val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')
+ with open(val_file, 'a') as f:
+ f.write('\neval results after training done\n\n')
+ metrics = ['chairs_epe', 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+',
+ 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40', 'things_clean_s40+',
+ 'things_final_epe', 'things_final_s0_10', 'things_final_s10_40', 'things_final_s40+',
+ 'sintel_clean_epe', 'sintel_clean_s0_10', 'sintel_clean_s10_40', 'sintel_clean_s40+',
+ 'sintel_final_epe', 'sintel_final_s0_10', 'sintel_final_s10_40', 'sintel_final_s40+',
+ 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+',
+ ]
+ eval_metrics = []
+ for metric in metrics:
+ if metric in val_results.keys():
+ eval_metrics.append(metric)
+
+ metrics_values = [val_results[metric] for metric in eval_metrics]
+
+ num_metrics = len(eval_metrics)
+
+ # save as markdown format
+ f.write(("| {:>20} " * num_metrics + '\n').format(*eval_metrics))
+ f.write(("| {:20.3f} " * num_metrics).format(*metrics_values))
+
+ f.write('\n\n')
+
+ return
+
+ # Sintel and KITTI submission
+ if args.submission:
+ # NOTE: args.val_dataset is a list
+ if args.val_dataset[0] == 'sintel':
+ create_sintel_submission(model_without_ddp,
+ output_path=args.output_path,
+ padding_factor=args.padding_factor,
+ save_vis_flow=args.save_vis_flow,
+ no_save_flo=args.no_save_flo,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ elif args.val_dataset[0] == 'kitti':
+ create_kitti_submission(model_without_ddp,
+ output_path=args.output_path,
+ padding_factor=args.padding_factor,
+ save_vis_flow=args.save_vis_flow,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ else:
+ raise ValueError(f'Not supported dataset for submission')
+
+ return
+
+ # inferece on a dir
+ if args.inference_dir is not None:
+ inference_on_dir(model_without_ddp,
+ inference_dir=args.inference_dir,
+ output_path=args.output_path,
+ padding_factor=args.padding_factor,
+ inference_size=args.inference_size,
+ paired_data=args.dir_paired_data,
+ save_flo_flow=args.save_flo_flow,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ pred_bidir_flow=args.pred_bidir_flow,
+ fwd_bwd_consistency_check=args.fwd_bwd_consistency_check,
+ )
+
+ return
+
+ # training datset
+ train_dataset = build_train_dataset(args)
+ print('Number of training images:', len(train_dataset))
+
+ # Multi-processing
+ if args.distributed:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_dataset,
+ num_replicas=torch.cuda.device_count(),
+ rank=args.local_rank)
+ else:
+ train_sampler = None
+
+ shuffle = False if args.distributed else True
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
+ shuffle=shuffle, num_workers=args.num_workers,
+ pin_memory=True, drop_last=True,
+ sampler=train_sampler)
+
+ last_epoch = start_step if args.resume and start_step > 0 else -1
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer, args.lr,
+ args.num_steps + 10,
+ pct_start=0.05,
+ cycle_momentum=False,
+ anneal_strategy='cos',
+ last_epoch=last_epoch,
+ )
+
+ if args.local_rank == 0:
+ summary_writer = SummaryWriter(args.checkpoint_dir)
+ logger = Logger(lr_scheduler, summary_writer, args.summary_freq,
+ start_step=start_step)
+
+ total_steps = start_step
+ epoch = start_epoch
+ print('Start training')
+
+ while total_steps < args.num_steps:
+ model.train()
+
+ # mannual change random seed for shuffling every epoch
+ if args.distributed:
+ train_sampler.set_epoch(epoch)
+
+ for i, sample in enumerate(train_loader):
+ img1, img2, flow_gt, valid = [x.to(device) for x in sample]
+
+ results_dict = model(img1, img2,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+
+ flow_preds = results_dict['flow_preds']
+
+ loss, metrics = flow_loss_func(flow_preds, flow_gt, valid,
+ gamma=args.gamma,
+ max_flow=args.max_flow,
+ )
+
+ if isinstance(loss, float):
+ continue
+
+ if torch.isnan(loss):
+ continue
+
+ metrics.update({'total_loss': loss.item()})
+
+ # more efficient zero_grad
+ for param in model_without_ddp.parameters():
+ param.grad = None
+
+ loss.backward()
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+
+ optimizer.step()
+
+ lr_scheduler.step()
+
+ if args.local_rank == 0:
+ logger.push(metrics)
+
+ logger.add_image_summary(img1, img2, flow_preds, flow_gt)
+
+ total_steps += 1
+
+ if total_steps % args.save_ckpt_freq == 0 or total_steps == args.num_steps:
+ if args.local_rank == 0:
+ checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps)
+ torch.save({
+ 'model': model_without_ddp.state_dict()
+ }, checkpoint_path)
+
+ if total_steps % args.save_latest_ckpt_freq == 0:
+ checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth')
+
+ if args.local_rank == 0:
+ torch.save({
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'step': total_steps,
+ 'epoch': epoch,
+ }, checkpoint_path)
+
+ if total_steps % args.val_freq == 0:
+ print('Start validation')
+
+ val_results = {}
+ # support validation on multiple datasets
+ if 'chairs' in args.val_dataset:
+ results_dict = validate_chairs(model_without_ddp,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ if args.local_rank == 0:
+ val_results.update(results_dict)
+
+ if 'things' in args.val_dataset:
+ results_dict = validate_things(model_without_ddp,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ if args.local_rank == 0:
+ val_results.update(results_dict)
+
+ if 'sintel' in args.val_dataset:
+ results_dict = validate_sintel(model_without_ddp,
+ count_time=args.count_time,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ evaluate_matched_unmatched=args.evaluate_matched_unmatched,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ if args.local_rank == 0:
+ val_results.update(results_dict)
+
+ if 'kitti' in args.val_dataset:
+ results_dict = validate_kitti(model_without_ddp,
+ padding_factor=args.padding_factor,
+ with_speed_metric=args.with_speed_metric,
+ attn_splits_list=args.attn_splits_list,
+ corr_radius_list=args.corr_radius_list,
+ prop_radius_list=args.prop_radius_list,
+ )
+ if args.local_rank == 0:
+ val_results.update(results_dict)
+
+ if args.local_rank == 0:
+ logger.write_dict(val_results)
+
+ # Save validation results
+ val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')
+ with open(val_file, 'a') as f:
+ f.write('step: %06d\n' % total_steps)
+ if args.evaluate_matched_unmatched:
+ metrics = ['chairs_epe',
+ 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+',
+ 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40',
+ 'things_clean_s40+',
+ 'sintel_clean_epe', 'sintel_clean_matched', 'sintel_clean_unmatched',
+ 'sintel_clean_s0_10', 'sintel_clean_s10_40',
+ 'sintel_clean_s40+',
+ 'sintel_final_epe', 'sintel_final_matched', 'sintel_final_unmatched',
+ 'sintel_final_s0_10', 'sintel_final_s10_40',
+ 'sintel_final_s40+',
+ 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+',
+ ]
+ else:
+ metrics = ['chairs_epe', 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+',
+ 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40',
+ 'things_clean_s40+',
+ 'sintel_clean_epe', 'sintel_clean_s0_10', 'sintel_clean_s10_40',
+ 'sintel_clean_s40+',
+ 'sintel_final_epe', 'sintel_final_s0_10', 'sintel_final_s10_40',
+ 'sintel_final_s40+',
+ 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+',
+ ]
+
+ eval_metrics = []
+ for metric in metrics:
+ if metric in val_results.keys():
+ eval_metrics.append(metric)
+
+ metrics_values = [val_results[metric] for metric in eval_metrics]
+
+ num_metrics = len(eval_metrics)
+
+ # save as markdown format
+ if args.evaluate_matched_unmatched:
+ f.write(("| {:>25} " * num_metrics + '\n').format(*eval_metrics))
+ f.write(("| {:25.3f} " * num_metrics).format(*metrics_values))
+ else:
+ f.write(("| {:>20} " * num_metrics + '\n').format(*eval_metrics))
+ f.write(("| {:20.3f} " * num_metrics).format(*metrics_values))
+
+ f.write('\n\n')
+
+ model.train()
+
+ if total_steps >= args.num_steps:
+ print('Training done')
+
+ return
+
+ epoch += 1
+
+
+if __name__ == '__main__':
+ parser = get_args_parser()
+ args = parser.parse_args()
+
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ main(args)
diff --git a/basicsr/archs/gmflow/scripts/demo.sh b/basicsr/archs/gmflow/scripts/demo.sh
new file mode 100755
index 0000000000000000000000000000000000000000..3aa5d2675781286d81512446af5865ff222491be
--- /dev/null
+++ b/basicsr/archs/gmflow/scripts/demo.sh
@@ -0,0 +1,63 @@
+#!/usr/bin/env bash
+
+# inference GMFlow without refinement
+
+# sintel
+
+# only predict forward flow
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/sintel_market_1 \
+--output_path output/gmflow-norefine-sintel_market_1 \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+# predict forward & backward flow
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/sintel_market_1 \
+--output_path output/gmflow-norefine-sintel_market_1 \
+--pred_bidir_flow \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+
+# predict forward & backward flow with forward-backward consistency check
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/sintel_market_1 \
+--output_path output/gmflow-norefine-sintel_market_1 \
+--pred_bidir_flow \
+--fwd_bwd_consistency_check \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+
+# davis
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/davis_breakdance-flare \
+--output_path output/gmflow-norefine-davis_breakdance-flare \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+
+
+
+# inference GMFlow with refinement
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/davis_breakdance-flare \
+--output_path output/gmflow-withrefine-davis_breakdance-flare \
+--resume pretrained/gmflow_with_refine_sintel-3ed1cf48.pth \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1
+
+
+
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--inference_dir demo/sintel_test_clean_market_1 \
+--output_path output/gmflow-norefine-sintel_test_clean_market_1 \
+--pred_bidir_flow \
+--fwd_bwd_consistency_check \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+
diff --git a/basicsr/archs/gmflow/scripts/evaluate.sh b/basicsr/archs/gmflow/scripts/evaluate.sh
new file mode 100755
index 0000000000000000000000000000000000000000..fa6dbefeddd2292a7fe5bfc277501080ccdd007a
--- /dev/null
+++ b/basicsr/archs/gmflow/scripts/evaluate.sh
@@ -0,0 +1,83 @@
+#!/usr/bin/env bash
+
+# evaluate GMFlow without refinement
+
+# evaluate chairs & things trained model on things and sintel (Table 3 of GMFlow paper)
+# the output should be:
+# Number of validation image pairs: 1024
+# Validation Things test set (things_clean) EPE: 3.475
+# Validation Things test (things_clean) s0_10: 0.666, s10_40: 1.310, s40+: 8.968
+# Number of validation image pairs: 1041
+# Validation Sintel (clean) EPE: 1.495, 1px: 0.161, 3px: 0.059, 5px: 0.040
+# Validation Sintel (clean) s0_10: 0.457, s10_40: 1.770, s40+: 8.257
+# Number of validation image pairs: 1041
+# Validation Sintel (final) EPE: 2.955, 1px: 0.209, 3px: 0.098, 5px: 0.071
+# Validation Sintel (final) s0_10: 0.725, s10_40: 3.446, s40+: 17.701
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--eval \
+--resume pretrained/gmflow_things-e9887eda.pth \
+--val_dataset things sintel \
+--with_speed_metric
+
+
+
+# evaluate GMFlow with refinement
+
+# evaluate chairs & things trained model on things and sintel (Table 3 of GMFlow paper)
+# the output should be:
+# Validation Things test set (things_clean) EPE: 2.804
+# Validation Things test (things_clean) s0_10: 0.527, s10_40: 1.009, s40+: 7.314
+# Number of validation image pairs: 1041
+# Validation Sintel (clean) EPE: 1.084, 1px: 0.092, 3px: 0.040, 5px: 0.028
+# Validation Sintel (clean) s0_10: 0.303, s10_40: 1.252, s40+: 6.261
+# Number of validation image pairs: 1041
+# Validation Sintel (final) EPE: 2.475, 1px: 0.147, 3px: 0.077, 5px: 0.058
+# Validation Sintel (final) s0_10: 0.511, s10_40: 2.810, s40+: 15.669
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--eval \
+--resume pretrained/gmflow_with_refine_things-36579974.pth \
+--val_dataset things sintel \
+--with_speed_metric \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1
+
+
+
+# evaluate matched & matched on sintel
+
+# evaluate GMFlow without refinement
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--eval \
+--evaluate_matched_unmatched \
+--resume pretrained/gmflow_things-e9887eda.pth \
+--val_dataset sintel
+
+# evaluate GMFlow with refinement
+
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--eval \
+--evaluate_matched_unmatched \
+--resume pretrained/gmflow_with_refine_things-36579974.pth \
+--val_dataset sintel \
+--with_speed_metric \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1
+
+
+
+
+
+
+
+
diff --git a/basicsr/archs/gmflow/scripts/submission.sh b/basicsr/archs/gmflow/scripts/submission.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c19223eafc3bb379a528cb16c7ff19f467a1c17a
--- /dev/null
+++ b/basicsr/archs/gmflow/scripts/submission.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+
+# generate prediction results for submission on sintel and kitti online servers
+
+
+# GMFlow without refinement
+
+# submission to sintel
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--submission \
+--output_path submission/sintel-gmflow-norefine \
+--val_dataset sintel \
+--resume pretrained/gmflow_sintel-0c07dcb3.pth
+
+# submission to kitti
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--submission \
+--output_path submission/kitti-gmflow-norefine \
+--val_dataset kitti \
+--resume pretrained/gmflow_kitti-285701a8.pth
+
+
+# you can also visualize the predictions before submission
+# CUDA_VISIBLE_DEVICES=0 python main.py \
+# --submission \
+# --output_path submission/sintel-gmflow-norefine-vis \
+# --save_vis_flow \
+# --no_save_flo \
+# --val_dataset sintel \
+# --resume pretrained/gmflow_sintel.pth
+
+
+
+
+# GMFlow with refinement
+
+# submission to sintel
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--submission \
+--output_path submission/sintel-gmflow-withrefine \
+--val_dataset sintel \
+--resume pretrained/gmflow_with_refine_sintel-3ed1cf48.pth \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1
+
+# submission to kitti
+CUDA_VISIBLE_DEVICES=0 python main.py \
+--submission \
+--output_path submission/kitti-gmflow-withrefine \
+--val_dataset kitti \
+--resume pretrained/gmflow_with_refine_kitti-8d3b9786.pth \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1
+
+
+
+
+
diff --git a/basicsr/archs/gmflow/scripts/train_gmflow.sh b/basicsr/archs/gmflow/scripts/train_gmflow.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a04fc583393c8ea35cb9298a1f814a75052f2a96
--- /dev/null
+++ b/basicsr/archs/gmflow/scripts/train_gmflow.sh
@@ -0,0 +1,108 @@
+#!/usr/bin/env bash
+
+# GMFlow without refinement
+
+# number of gpus for training, please set according to your hardware
+# by default use all gpus on a machine
+# can be trained on 4x 16GB V100 or 2x 32GB V100 or 2x 40GB A100 gpus
+NUM_GPUS=4
+
+# chairs
+CHECKPOINT_DIR=checkpoints/chairs-gmflow && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--batch_size 16 \
+--val_dataset chairs sintel kitti \
+--lr 4e-4 \
+--image_size 384 512 \
+--padding_factor 16 \
+--upsample_factor 8 \
+--with_speed_metric \
+--val_freq 10000 \
+--save_ckpt_freq 10000 \
+--num_steps 100000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# things (our final model is trained for 800K iterations, for ablation study, you can train for 200K)
+CHECKPOINT_DIR=checkpoints/things-gmflow && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/chairs-gmflow/step_100000.pth \
+--stage things \
+--batch_size 8 \
+--val_dataset things sintel kitti \
+--lr 2e-4 \
+--image_size 384 768 \
+--padding_factor 16 \
+--upsample_factor 8 \
+--with_speed_metric \
+--val_freq 40000 \
+--save_ckpt_freq 50000 \
+--num_steps 800000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# sintel
+CHECKPOINT_DIR=checkpoints/sintel-gmflow && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/things-gmflow/step_800000.pth \
+--stage sintel \
+--batch_size 8 \
+--val_dataset sintel kitti \
+--lr 2e-4 \
+--image_size 320 896 \
+--padding_factor 16 \
+--upsample_factor 8 \
+--with_speed_metric \
+--val_freq 20000 \
+--save_ckpt_freq 20000 \
+--num_steps 200000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# kitti
+CHECKPOINT_DIR=checkpoints/kitti-gmflow && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/sintel-gmflow/step_200000.pth \
+--stage kitti \
+--batch_size 8 \
+--val_dataset kitti \
+--lr 2e-4 \
+--image_size 320 1152 \
+--padding_factor 16 \
+--upsample_factor 8 \
+--with_speed_metric \
+--val_freq 10000 \
+--save_ckpt_freq 10000 \
+--num_steps 100000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+
+# a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint
+# an example: resume chairs training
+# CHECKPOINT_DIR=checkpoints/chairs-gmflow && \
+# mkdir -p ${CHECKPOINT_DIR} && \
+# python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+# --launcher pytorch \
+# --checkpoint_dir ${CHECKPOINT_DIR} \
+# --resume checkpoints/chairs-gmflow/checkpoint_latest.pth \
+# --batch_size 16 \
+# --val_dataset chairs sintel kitti \
+# --lr 4e-4 \
+# --image_size 384 512 \
+# --padding_factor 16 \
+# --upsample_factor 8 \
+# --with_speed_metric \
+# --val_freq 10000 \
+# --save_ckpt_freq 10000 \
+# --num_steps 100000 \
+# 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
diff --git a/basicsr/archs/gmflow/scripts/train_gmflow_with_refine.sh b/basicsr/archs/gmflow/scripts/train_gmflow_with_refine.sh
new file mode 100755
index 0000000000000000000000000000000000000000..db8ed3d423fd2993cb3c25a171fc52abe2c4c792
--- /dev/null
+++ b/basicsr/archs/gmflow/scripts/train_gmflow_with_refine.sh
@@ -0,0 +1,128 @@
+#!/usr/bin/env bash
+
+# GMFlow with refinement
+
+# number of gpus for training, please set according to your hardware
+# by default use all gpus on a machine
+# can be trained on 4x 32G V100 or 4x 40GB A100 or 8x 16G V100 gpus
+NUM_GPUS=4
+
+# chairs
+CHECKPOINT_DIR=checkpoints/chairs-gmflow_with_refine && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--batch_size 16 \
+--val_dataset chairs sintel kitti \
+--lr 4e-4 \
+--image_size 384 512 \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1 \
+--with_speed_metric \
+--val_freq 10000 \
+--save_ckpt_freq 10000 \
+--num_steps 100000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# things (our final model is trained for 800K iterations, for ablation study, you can train for 200K)
+CHECKPOINT_DIR=checkpoints/things-gmflow_with_refine && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/chairs-gmflow_with_refine/step_100000.pth \
+--stage things \
+--batch_size 8 \
+--val_dataset things sintel kitti \
+--lr 2e-4 \
+--image_size 384 768 \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1 \
+--with_speed_metric \
+--val_freq 40000 \
+--save_ckpt_freq 50000 \
+--num_steps 800000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# sintel
+CHECKPOINT_DIR=checkpoints/sintel-gmflow_with_refine && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/things-gmflow_with_refine/step_800000.pth \
+--stage sintel \
+--batch_size 8 \
+--val_dataset sintel kitti \
+--lr 2e-4 \
+--image_size 320 896 \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1 \
+--with_speed_metric \
+--val_freq 20000 \
+--save_ckpt_freq 20000 \
+--num_steps 200000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+# kitti
+CHECKPOINT_DIR=checkpoints/kitti-gmflow_with_refine && \
+mkdir -p ${CHECKPOINT_DIR} && \
+python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+--launcher pytorch \
+--checkpoint_dir ${CHECKPOINT_DIR} \
+--resume checkpoints/sintel-gmflow_with_refine/step_200000.pth \
+--stage kitti \
+--batch_size 8 \
+--val_dataset kitti \
+--lr 2e-4 \
+--image_size 320 1152 \
+--padding_factor 32 \
+--upsample_factor 4 \
+--num_scales 2 \
+--attn_splits_list 2 8 \
+--corr_radius_list -1 4 \
+--prop_radius_list -1 1 \
+--with_speed_metric \
+--val_freq 10000 \
+--save_ckpt_freq 10000 \
+--num_steps 100000 \
+2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
+
+
+
+# a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint
+# an example: resume chairs training
+# CHECKPOINT_DIR=checkpoints/chairs-gmflow_with_refine && \
+# mkdir -p ${CHECKPOINT_DIR} && \
+# python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \
+# --launcher pytorch \
+# --checkpoint_dir ${CHECKPOINT_DIR} \
+# --resume checkpoints/chairs-gmflow_with_refine/checkpoint_latest.pth \
+# --batch_size 16 \
+# --val_dataset chairs sintel kitti \
+# --lr 4e-4 \
+# --image_size 384 512 \
+# --padding_factor 32 \
+# --upsample_factor 4 \
+# --num_scales 2 \
+# --attn_splits_list 2 8 \
+# --corr_radius_list -1 4 \
+# --prop_radius_list -1 1 \
+# --with_speed_metric \
+# --val_freq 10000 \
+# --save_ckpt_freq 10000 \
+# --num_steps 100000 \
+# 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
diff --git a/basicsr/archs/gmflow/utils/dist_utils.py b/basicsr/archs/gmflow/utils/dist_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..3c70f089225ad8cfb741f71809f4018c11711a72
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/dist_utils.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py
+
+import os
+import subprocess
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ # use MASTER_ADDR in the environment variable if it already exists
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
diff --git a/basicsr/archs/gmflow/utils/flow_viz.py b/basicsr/archs/gmflow/utils/flow_viz.py
new file mode 100755
index 0000000000000000000000000000000000000000..9b782c07841b27526ef8c9fa070b480a01545c31
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/flow_viz.py
@@ -0,0 +1,291 @@
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def make_colorwheel():
+ '''
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ '''
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
+
+
+def flow_compute_color(u, v, convert_to_bgr=False):
+ '''
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param u: np.ndarray, input horizontal flow
+ :param v: np.ndarray, input vertical flow
+ :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
+ :return:
+ '''
+
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 1
+ f = fk - k0
+
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range?
+
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+
+ return flow_image
+
+
+def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
+ '''
+ Expects a two dimensional flow image of shape [H,W,2]
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param flow_uv: np.ndarray of shape [H,W,2]
+ :param clip_flow: float, maximum clipping value for flow
+ :return:
+ '''
+
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+
+ return flow_compute_color(u, v, convert_to_bgr)
+
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+ col += + BM
+
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col + MR, 0] = 255
+
+ return colorwheel
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols + 1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0 - 1] / 255
+ col1 = tmp[k1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+
+ return img
+
+
+# from https://github.com/gengshan-y/VCN
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ maxu = max(maxu, np.max(u))
+ minu = min(minu, np.min(u))
+
+ maxv = max(maxv, np.max(v))
+ minv = min(minv, np.min(v))
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, np.max(rad))
+
+ u = u / (maxrad + np.finfo(float).eps)
+ v = v / (maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
+def save_vis_flow_tofile(flow, output_path):
+ vis_flow = flow_to_image(flow)
+ from PIL import Image
+ img = Image.fromarray(vis_flow)
+ img.save(output_path)
+
+
+def flow_tensor_to_image(flow):
+ """Used for tensorboard visualization"""
+ flow = flow.permute(1, 2, 0) # [H, W, 2]
+ flow = flow.detach().cpu().numpy()
+ flow = flow_to_image(flow) # [H, W, 3]
+ flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
+
+ return flow
diff --git a/basicsr/archs/gmflow/utils/frame_utils.py b/basicsr/archs/gmflow/utils/frame_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..9005ed1d2005d25456e620c467a5d688e8c0a783
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/frame_utils.py
@@ -0,0 +1,131 @@
+import numpy as np
+from PIL import Image
+from os.path import *
+import re
+import cv2
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
+ # Reshape testdata into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data
+
+
+def writeFlow(filename, uv, v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert (uv.ndim == 3)
+ assert (uv.shape[2] == 2)
+ u = uv[:, :, 0]
+ v = uv[:, :, 1]
+ else:
+ u = uv
+
+ assert (u.shape == v.shape)
+ height, width = u.shape
+ f = open(filename, 'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width * nBands))
+ tmp[:, np.arange(width) * 2] = u
+ tmp[:, np.arange(width) * 2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+def readFlowKITTI(filename):
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+ flow = flow[:, :, ::-1].astype(np.float32)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2 ** 15) / 64.0
+ return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+ uv = 64.0 * uv + 2 ** 15
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+ cv2.imwrite(filename, uv[..., ::-1])
+
+
+def read_gen(file_name, pil=False):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ return Image.open(file_name)
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return readFlow(file_name).astype(np.float32)
+ elif ext == '.pfm':
+ flow = readPFM(file_name).astype(np.float32)
+ if len(flow.shape) == 2:
+ return flow
+ else:
+ return flow[:, :, :-1]
+ return []
diff --git a/basicsr/archs/gmflow/utils/logger.py b/basicsr/archs/gmflow/utils/logger.py
new file mode 100755
index 0000000000000000000000000000000000000000..07ab133fb0d4bbc9e84918fd276eb429f06d730a
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/logger.py
@@ -0,0 +1,68 @@
+import torch
+
+from utils.flow_viz import flow_tensor_to_image
+
+
+class Logger:
+ def __init__(self, lr_scheduler,
+ summary_writer,
+ summary_freq=100,
+ start_step=0,
+ ):
+ self.lr_scheduler = lr_scheduler
+ self.total_steps = start_step
+ self.running_loss = {}
+ self.summary_writer = summary_writer
+ self.summary_freq = summary_freq
+
+ def print_training_status(self, mode='train'):
+
+ print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq))
+
+ for k in self.running_loss:
+ self.summary_writer.add_scalar(mode + '/' + k,
+ self.running_loss[k] / self.summary_freq, self.total_steps)
+ self.running_loss[k] = 0.0
+
+ def lr_summary(self):
+ lr = self.lr_scheduler.get_last_lr()[0]
+ self.summary_writer.add_scalar('lr', lr, self.total_steps)
+
+ def add_image_summary(self, img1, img2, flow_preds, flow_gt, mode='train',
+ ):
+ if self.total_steps % self.summary_freq == 0:
+ img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
+ img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard
+
+ flow_pred = flow_tensor_to_image(flow_preds[-1][0])
+ forward_flow_gt = flow_tensor_to_image(flow_gt[0])
+ flow_concat = torch.cat((torch.from_numpy(flow_pred),
+ torch.from_numpy(forward_flow_gt)), dim=-1)
+
+ concat = torch.cat((img_concat, flow_concat), dim=-2)
+
+ self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)
+
+ def push(self, metrics, mode='train'):
+ self.total_steps += 1
+
+ self.lr_summary()
+
+ for key in metrics:
+ if key not in self.running_loss:
+ self.running_loss[key] = 0.0
+
+ self.running_loss[key] += metrics[key]
+
+ if self.total_steps % self.summary_freq == 0:
+ self.print_training_status(mode)
+ self.running_loss = {}
+
+ def write_dict(self, results):
+ for key in results:
+ tag = key.split('_')[0]
+ tag = tag + '/' + key
+ self.summary_writer.add_scalar(tag, results[key], self.total_steps)
+
+ def close(self):
+ self.summary_writer.close()
diff --git a/basicsr/archs/gmflow/utils/misc.py b/basicsr/archs/gmflow/utils/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..c2de906d8181e9e24d2f51e0be03a19c04960d06
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/misc.py
@@ -0,0 +1,42 @@
+import os
+import numpy as np
+import sys
+import json
+
+
+def read_text_lines(filepath):
+ with open(filepath, 'r') as f:
+ lines = f.readlines()
+ lines = [l.rstrip() for l in lines]
+ return lines
+
+
+def check_path(path):
+ if not os.path.exists(path):
+ os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing
+
+
+def save_command(save_path, filename='command_train.txt'):
+ check_path(save_path)
+ command = sys.argv
+ save_file = os.path.join(save_path, filename)
+ # Save all training commands when resuming training
+ with open(save_file, 'a') as f:
+ f.write(' '.join(command))
+ f.write('\n\n')
+
+
+def save_args(args, filename='args.json'):
+ args_dict = vars(args)
+ check_path(args.checkpoint_dir)
+ save_path = os.path.join(args.checkpoint_dir, filename)
+
+ # Save all training args when resuming training
+ with open(save_path, 'a') as f:
+ json.dump(args_dict, f, indent=4, sort_keys=False)
+ f.write('\n\n')
+
+
+def int_list(s):
+ """Convert string to int list"""
+ return [int(x) for x in s.split(',')]
diff --git a/basicsr/archs/gmflow/utils/utils.py b/basicsr/archs/gmflow/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..76f5518b7e5b769527907b31a1c1c00ba6cfe4f1
--- /dev/null
+++ b/basicsr/archs/gmflow/utils/utils.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn.functional as F
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+
+ def __init__(self, dims, mode='sintel', padding_factor=8):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
+ pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
+ if mode == 'sintel':
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+ else:
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self, x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+def coords_grid(batch, ht, wd, normalize=False):
+ if normalize: # [-1, 1]
+ coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
+ 2 * torch.arange(wd) / (wd - 1) - 1)
+ else:
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W]
+
+
+def compute_out_of_boundary_mask(flow):
+ # flow: [B, 2, H, W]
+ assert flow.dim() == 4 and flow.size(1) == 2
+ b, _, h, w = flow.shape
+ init_coords = coords_grid(b, h, w).to(flow.device)
+ corres = init_coords + flow # [B, 2, H, W]
+
+ max_w = w - 1
+ max_h = h - 1
+
+ valid_mask = (corres[:, 0] >= 0) & (corres[:, 0] <= max_w) & (corres[:, 1] >= 0) & (corres[:, 1] <= max_h)
+
+ # in case very large flow
+ flow_mask = (flow[:, 0].abs() <= max_w) & (flow[:, 1].abs() <= max_h)
+
+ valid_mask = valid_mask & flow_mask
+
+ return valid_mask # [B, H, W]
+
+
+def count_parameters(model):
+ num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num
diff --git a/basicsr/archs/gmflow_arch.py b/basicsr/archs/gmflow_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f650d1aad82a9e67fa25cc702927b7f957b4d2
--- /dev/null
+++ b/basicsr/archs/gmflow_arch.py
@@ -0,0 +1,82 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pdb
+
+from basicsr.archs.gmflow.gmflow.gmflow import GMFlow
+
+
+class FlowGenerator(nn.Module):
+ """GM flow generation.
+
+ Args:
+ path (str): Pre-trained path. Default: None.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ """
+
+ def __init__(self,
+ path=None,
+ requires_grad=False,):
+ super().__init__()
+
+ self.model = GMFlow()
+
+ if path != None:
+ weights = torch.load(
+ path, map_location=lambda storage, loc: storage)['model']
+ self.model.load_state_dict(weights, strict=True)
+
+ if not requires_grad:
+ self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.model.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ def forward(self, im1, im2,
+ attn_splits_list=[2],
+ corr_radius_list=[-1],
+ prop_radius_list=[-1]):
+ """Forward function.
+
+ Args:
+ im1 (Tensor): Input tensor with shape (n, c, h, w).
+ im2 (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ assert im1.shape == im2.shape
+ N, C, H, W = im1.shape
+
+ im1 = (im1 + 1) / 2 * 255
+ im2 = (im2 + 1) / 2 * 255
+
+ flow = self.model(im1, im2,
+ attn_splits_list=attn_splits_list,
+ corr_radius_list=corr_radius_list,
+ prop_radius_list=prop_radius_list,
+ pred_bidir_flow=False)['flow_preds'][-1]
+ # backward_flow = flow[N:]
+
+ return flow
+
+
+if __name__ == '__main__':
+ h, w = 512, 512
+ # model = RAFT().cuda()
+ model = FlowGenerator(
+ load_path='../../weights/GMFlow/gmflow_sintel-0c07dcb3.pth').cuda()
+ model.eval()
+ print(model)
+
+ x = torch.randn((1, 3, h, w)).cuda()
+ y = torch.randn((1, 3, h, w)).cuda()
+ with torch.no_grad():
+ out = model(x, y)
+ pdb.set_trace()
+ print(out.shape)
diff --git a/basicsr/archs/keep_arch.py b/basicsr/archs/keep_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef5753dbc8a722cb86fe3b61a9cf9ccf23a887f
--- /dev/null
+++ b/basicsr/archs/keep_arch.py
@@ -0,0 +1,936 @@
+import math
+from re import T
+import numpy as np
+import pdb
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+from torch.profiler import profile, record_function, ProfilerActivity
+from collections import defaultdict
+
+# from gpu_mem_track import MemTracker
+from einops import rearrange, repeat
+
+from basicsr.archs.vqgan_arch import Encoder, VectorQuantizer, GumbelQuantizer, Generator, ResBlock
+from basicsr.archs.arch_util import flow_warp, resize_flow
+from basicsr.archs.gmflow_arch import FlowGenerator
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+
+# gpu_tracker = MemTracker()
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have the similar color and illuminations
+ as those in the degradate features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degradate features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)
+ ) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)),
+ device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats,
+ dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ # self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.MultiheadAttention):
+ nn.init.xavier_uniform_(module.in_proj_weight)
+ nn.init.xavier_uniform_(module.out_proj.weight)
+ if module.in_proj_bias is not None:
+ nn.init.constant_(module.in_proj_bias, 0.)
+ nn.init.constant_(module.out_proj.bias, 0.)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ module.weight.data.zero_()
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ # print(enc_feat.shape, dec_feat.shape)
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+class CrossFrameFusionLayer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout,
+ activation_fn=activation_fn)
+
+ # Cross Frame Attention
+ self.attn = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn.to_out[0].weight.data)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ module.weight.data.zero_()
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+ def forward(self, curr_states, prev_states, residual=True):
+ B, C, H, W = curr_states.shape
+ curr_states = rearrange(curr_states, "b c h w -> b (h w) c")
+ prev_states = rearrange(prev_states, "b c h w -> b (h w) c")
+
+ if residual:
+ res = curr_states
+
+ curr_states = self.attn(curr_states, prev_states)
+ curr_states = self.norm1(curr_states)
+
+ if residual:
+ curr_states = curr_states + res
+ res = curr_states
+
+ curr_states = self.ff(curr_states)
+ curr_states = self.norm2(curr_states)
+
+ if residual:
+ curr_states = curr_states + res
+
+ curr_states = rearrange(curr_states, "b (h w) c -> b c h w", h=H)
+ return curr_states
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+ # SC-Attn
+ self.attn1 = SparseCausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(
+ dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # # Cross-Attn
+ # if cross_attention_dim is not None:
+ # self.attn2 = CrossAttention(
+ # query_dim=dim,
+ # cross_attention_dim=cross_attention_dim,
+ # heads=num_attention_heads,
+ # dim_head=attention_head_dim,
+ # dropout=dropout,
+ # bias=attention_bias,
+ # upcast_attention=upcast_attention,
+ # )
+ # else:
+ # self.attn2 = None
+
+ # if cross_attention_dim is not None:
+ # self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ # else:
+ # self.norm2 = None
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout,
+ activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # Temp-Attn
+ self.attn_temp = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(
+ dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(
+ hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states,
+ attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(
+ norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ # if self.attn2 is not None:
+ # # Cross-Attention
+ # norm_hidden_states = (
+ # self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ # )
+ # hidden_states = (
+ # self.attn2(
+ # norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ # )
+ # + hidden_states
+ # )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(
+ hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+
+class SparseCausalAttention(CrossAttention):
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(
+ hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
+
+ # d = h*w
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ key = torch.cat([key[:, [0] * video_length],
+ key[:, former_frame_index]], dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c")
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ value = torch.cat([value[:, [0] * video_length],
+ value[:, former_frame_index]], dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c")
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(
+ attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(
+ self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(
+ query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(
+ query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(
+ query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+class KalmanFilter(nn.Module):
+ def __init__(self, emb_dim, num_attention_heads,
+ attention_head_dim, num_uncertainty_layers):
+ super().__init__()
+ self.uncertainty_estimator = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ emb_dim,
+ num_attention_heads,
+ attention_head_dim,
+ )
+ for d in range(num_uncertainty_layers)
+ ]
+ )
+
+ self.kalman_gain_calculator = nn.Sequential(
+ ResBlock(emb_dim, emb_dim),
+ ResBlock(emb_dim, emb_dim),
+ ResBlock(emb_dim, emb_dim),
+ nn.Conv2d(emb_dim, 1, kernel_size=1, padding=0),
+ nn.Sigmoid()
+ )
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+ def predict(self, z_hat, flow):
+ # Predict the next state based on the current state and flow (if available)
+ flow = rearrange(flow, "n c h w -> n h w c")
+ z_prime = flow_warp(z_hat, flow)
+ return z_prime
+
+ def update(self, z_code, z_prime, gain):
+ # Update the state and uncertainty based on the measurement and Kalman gain
+ z_hat = (1 - gain) * z_code + gain * z_prime
+ return z_hat
+
+ def calc_gain(self, z_codes):
+ assert z_codes.dim(
+ ) == 5, f"Expected z_codes to have ndim=5, but got ndim={z_codes.dim()}."
+ video_length = z_codes.shape[1]
+ height, width = z_codes.shape[3:5]
+
+ # Assume input shape of uncertainty_estimator to be [(b f) d c]
+ z_tmp = rearrange(z_codes, "b f c h w -> (b f) (h w) c")
+ h_codes = z_tmp
+ for block in self.uncertainty_estimator:
+ h_codes = block(h_codes, video_length=video_length)
+
+ h_codes = rearrange(
+ h_codes, "(b f) (h w) c -> (b f) c h w", h=height, f=video_length)
+ w_codes = self.kalman_gain_calculator(h_codes)
+
+ w_codes = rearrange(
+ w_codes, "(b f) c h w -> b f c h w", f=video_length)
+
+ # pdb.set_trace()
+ return w_codes
+
+
+def load_vqgan_checkpoint(model, vqgan_path, logger=None):
+ """Load VQGAN checkpoint into model components.
+
+ Args:
+ model: The model to load weights into
+ vqgan_path (str): Path to the VQGAN checkpoint
+ logger: Logger instance
+ """
+ if logger is None:
+ logger = get_root_logger()
+
+ # Load VQGAN checkpoint, load params_ema or params
+ ckpt = torch.load(vqgan_path, map_location='cpu', weights_only=True)
+ if 'params_ema' in ckpt:
+ state_dict = ckpt['params_ema']
+ logger.info(f'Loading VQGAN from: {vqgan_path} [params_ema]')
+ elif 'params' in ckpt:
+ state_dict = ckpt['params']
+ logger.info(f'Loading VQGAN from: {vqgan_path} [params]')
+ else:
+ raise ValueError(f'Wrong params in checkpoint: {vqgan_path}')
+
+ # Load encoder weights into both encoders
+ encoder_state_dict = {k.split('encoder.')[-1]: v for k, v in state_dict.items() if k.startswith('encoder.')}
+ model.encoder.load_state_dict(encoder_state_dict, strict=True)
+ model.hq_encoder.load_state_dict(encoder_state_dict, strict=True)
+
+ # Load quantizer weights
+ quantizer_state_dict = {k.split('quantize.')[-1]: v for k, v in state_dict.items() if k.startswith('quantize.')}
+ model.quantize.load_state_dict(quantizer_state_dict, strict=True)
+
+ # Load generator weights
+ generator_state_dict = {k.split('generator.')[-1]: v for k, v in state_dict.items() if k.startswith('generator.')}
+ model.generator.load_state_dict(generator_state_dict, strict=True)
+
+
+@ARCH_REGISTRY.register()
+class KEEP(nn.Module):
+ def __init__(self, img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8], quantizer_type="nearest",
+ res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, vqgan_path=None,
+ dim_embd=512, n_head=8, n_layers=9, latent_size=256,
+ cft_list=['32', '64', '128', '256'], fix_modules=['quantize', 'generator'],
+ flownet_path=None, kalman_attn_head_dim=64, num_uncertainty_layers=4,
+ cond=1, cfa_list=[], cfa_nhead=4, cfa_dim=256,
+ cfa_nlayers=4, cross_residual=True,
+ temp_reg_list=[], mask_ratio=0.):
+ super().__init__()
+
+ self.cond = cond
+ self.cft_list = cft_list
+ self.cfa_list = cfa_list
+ self.temp_reg_list = temp_reg_list
+ self.use_residual = cross_residual
+ self.mask_ratio = mask_ratio
+ self.latent_size = latent_size
+ logger = get_root_logger()
+
+ # alignment
+ self.flownet = FlowGenerator(path=flownet_path)
+
+ # Kalman Filter
+ self.kalman_filter = KalmanFilter(
+ emb_dim=emb_dim,
+ num_attention_heads=n_head,
+ attention_head_dim=kalman_attn_head_dim,
+ num_uncertainty_layers=num_uncertainty_layers,
+ )
+
+ # Create encoders with same architecture
+ encoder_config = dict(
+ in_channels=3,
+ nf=nf,
+ emb_dim=emb_dim,
+ ch_mult=ch_mult,
+ num_res_blocks=res_blocks,
+ resolution=img_size,
+ attn_resolutions=attn_resolutions
+ )
+
+ self.hq_encoder = Encoder(**encoder_config)
+ self.encoder = Encoder(**encoder_config)
+
+ # VQGAN components
+ if quantizer_type == "nearest":
+ self.quantize = VectorQuantizer(codebook_size, emb_dim, beta)
+ elif quantizer_type == "gumbel":
+ self.quantize = GumbelQuantizer(
+ codebook_size, emb_dim, emb_dim, gumbel_straight_through, gumbel_kl_weight
+ )
+
+ self.generator = Generator(
+ nf=nf,
+ emb_dim=emb_dim,
+ ch_mult=ch_mult,
+ res_blocks=res_blocks,
+ img_size=img_size,
+ attn_resolutions=attn_resolutions
+ )
+
+ # Load VQGAN checkpoint if provided
+ if vqgan_path is not None:
+ load_vqgan_checkpoint(self, vqgan_path, logger)
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, dim_embd))
+ self.feat_emb = nn.Linear(emb_dim, dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head,
+ dim_mlp=dim_embd*2, dropout=0.0) for _ in range(n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {
+ '512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {
+ '16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}
+
+ # cross frame attention fusion
+ self.cfa = nn.ModuleDict()
+ for f_size in self.cfa_list:
+ in_ch = self.channels[f_size]
+ self.cfa[f_size] = CrossFrameFusionLayer(dim=in_ch,
+ num_attention_heads=cfa_nhead,
+ attention_head_dim=cfa_dim)
+
+ # Controllable Feature Transformation (CFT)
+ self.cft = nn.ModuleDict()
+ for f_size in self.cft_list:
+ in_ch = self.channels[f_size]
+ self.cft[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+
+ def get_flow(self, x):
+ b, t, c, h, w = x.size()
+
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ # Forward flow
+ with torch.no_grad():
+ flows = self.flownet(x_2, x_1).view(b, t - 1, 2, h, w)
+
+ return flows.detach()
+
+ def mask_by_ratio(self, x, mask_ratio=0.):
+ if mask_ratio == 0:
+ return x
+
+ # B F C H W
+ b, t, c, h, w = x.size()
+ d = h * w
+ x = rearrange(x, "b f c h w -> b f (h w) c")
+
+ len_keep = int(d * (1 - mask_ratio))
+ sample = torch.rand((b, t, d, 1), device=x.device).topk(
+ len_keep, dim=2).indices
+ mask = torch.zeros((b, t, d, 1), dtype=torch.bool, device=x.device)
+ mask.scatter_(dim=2, index=sample, value=True)
+
+ x = mask * x
+ x = rearrange(x, "b f (h w) c -> b f c h w", h=h)
+
+ return x
+
+ def forward(self, x, detach_16=True, early_feat=True, need_upscale=True):
+ """Forward function for KEEP.
+
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence of
+ shape (b, t, c, h, w).
+
+ Returns:
+ Tensor: Output HR sequence with shape (b, t, c, 4h, 4w).
+ """
+ video_length = x.shape[1]
+
+ if need_upscale:
+ x = rearrange(x, "b f c h w -> (b f) c h w")
+ x = F.interpolate(x, scale_factor=4, mode='bilinear')
+ x = rearrange(x, "(b f) c h w -> b f c h w", f=video_length)
+
+ b, t, c, h, w = x.size()
+ flows = self.get_flow(x) # (B, t-1, 2, H , W)
+
+ # ################### Encoder #####################
+ # BTCHW -> (BT)CHW
+ x = x.reshape(-1, c, h, w)
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size]
+ for f_size in self.cft_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = rearrange(x, "(b f) c h w -> b f c h w", f=t).detach()
+
+ lq_feat = x
+
+ # gpu_tracker.track('After encoder')
+ # ################### Kalman Filter ###############
+ z_codes = rearrange(x, "(b f) c h w -> b f c h w", f=t)
+ if self.training:
+ z_codes = self.mask_by_ratio(z_codes, self.mask_ratio)
+ gains = self.kalman_filter.calc_gain(z_codes)
+
+ outs = []
+ logits = []
+ cross_prev_feat = {}
+ gen_feat_dict = defaultdict(list)
+
+ cft_list = [self.fuse_generator_block[f_size]
+ for f_size in self.cft_list]
+
+ cfa_list = [self.fuse_generator_block[f_size]
+ for f_size in self.cfa_list]
+
+ temp_reg_list = [self.fuse_generator_block[f_size]
+ for f_size in self.temp_reg_list]
+
+ for i in range(video_length):
+ # print(f'Frame {i} ...')
+ if i == 0:
+ z_hat = z_codes[:, i, ...]
+ else:
+ z_prime = self.hq_encoder(
+ self.kalman_filter.predict(prev_out.detach(), flows[:, i-1, ...]))
+ z_hat = self.kalman_filter.update(
+ z_codes[:, i, ...], z_prime, gains[:, i, ...])
+
+ # ################# Transformer ###################
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1, b, 1)
+ # BCHW -> BC(HW) -> (HW)BC
+ query_emb = self.feat_emb(z_hat.flatten(2).permute(2, 0, 1))
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logit = self.idx_pred_layer(query_emb).permute(
+ 1, 0, 2) # (hw)bn -> b(hw)n
+ logits.append(logit)
+
+ # ################# Quantization ###################
+ code_h = int(np.sqrt(self.latent_size))
+ soft_one_hot = F.softmax(logit, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(
+ top_idx, shape=[b, code_h, code_h, 256])
+
+ if detach_16:
+ # for training stage III
+ quant_feat = quant_feat.detach()
+ else:
+ # preserve gradients for stage II
+ quant_feat = query_emb + (quant_feat - query_emb).detach()
+
+ # ################## Generator ####################
+ x = quant_feat
+
+ for j, block in enumerate(self.generator.blocks):
+ x = block(x)
+
+ if j in cft_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ # pdb.set_trace()
+ x = self.cft[f_size](
+ enc_feat_dict[f_size][:, i, ...], x, self.cond)
+
+ if j in cfa_list:
+ f_size = str(x.shape[-1])
+
+ if i == 0:
+ cross_prev_feat[f_size] = x
+ # print(f_size)
+ else:
+ # pdb.set_trace()
+ prev_fea = cross_prev_feat[f_size]
+ x = self.cfa[f_size](
+ x, prev_fea, residual=self.use_residual)
+ cross_prev_feat[f_size] = x
+
+ if j in temp_reg_list:
+ f_size = str(x.shape[-1])
+ gen_feat_dict[f_size].append(x)
+
+ prev_out = x # B C H W
+ outs.append(prev_out)
+
+ for f_size, feat in gen_feat_dict.items():
+ gen_feat_dict[f_size] = torch.stack(feat, dim=1) # bfchw
+
+ # Convert defaultdict to regular dict before returning
+ gen_feat_dict = dict(gen_feat_dict)
+
+ logits = torch.stack(logits, dim=1) # b(hw)n -> bf(hw)n
+ logits = rearrange(logits, "b f l n -> (b f) l n")
+ outs = torch.stack(outs, dim=1) # bfchw
+ if self.training:
+ if early_feat:
+ return outs, logits, lq_feat, gen_feat_dict
+ else:
+ return outs, gen_feat_dict
+ else:
+ return outs
+
+
+def count_parameters(model):
+ # Initialize counters
+ total_params = 0
+ sub_module_params = {}
+
+ # Loop through all the modules in the model
+ for name, module in model.named_children():
+ # if len(list(module.children())) == 0: # Check if it's a leaf module
+ params = sum(p.numel() for p in module.parameters())
+ total_params += params
+ sub_module_params[name] = params
+
+ return total_params, sub_module_params
+
+
+if __name__ == '__main__':
+ import time
+ batch_size = 1
+ video_length = 4
+ height = 128
+ width = 128
+
+ model = KEEP(
+ img_size=512,
+ emb_dim=256,
+ ch_mult=[1, 2, 2, 4, 4, 8],
+ dim_embd=512,
+ n_head=8,
+ n_layers=4,
+ codebook_size=1024,
+ cft_list=[],
+ fix_modules=['generator', 'quantize', 'flownet', 'cft', 'hq_encoder',
+ 'encoder', 'feat_emb', 'ft_layers', 'idx_pred_layer'],
+ flownet_path="../../weights/GMFlow/gmflow_sintel-0c07dcb3.pth",
+ kalman_attn_head_dim=32,
+ num_uncertainty_layers=3,
+ cond=0,
+ cfa_list=['32'],
+ cfa_nhead=4,
+ cfa_dim=256,
+ temp_reg_list=['64'],
+ ).cuda()
+
+ total_params = sum(map(lambda x: x.numel(), model.parameters()))
+ print(f"Total parameters in the model: {total_params / 1e6:.2f} M")
+
+ dummy_input = torch.randn((1, 20, 3, 128, 128)).cuda()
+
+ start_time = time.time()
+
+ with torch.no_grad():
+ for _ in range(100):
+ out = model(dummy_input)
+ elapsed_time = time.time() - start_time
+
+ print(f"Forward pass time: {elapsed_time / 100 / 20 * 1000:.2f} ms")
+ print(out.shape)
diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a2d6c204557cba53ada7550deb587541855cfb
--- /dev/null
+++ b/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Emperically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Emperically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+ num_block (int): Block number in the trunk network. Defaults: 23
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
\ No newline at end of file
diff --git a/basicsr/archs/spectral_norm_arch.py b/basicsr/archs/spectral_norm_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ed53d70123bc01306a7b38b617f54d116c8a5f
--- /dev/null
+++ b/basicsr/archs/spectral_norm_arch.py
@@ -0,0 +1,288 @@
+"""
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+"""
+import torch
+from torch.nn.functional import normalize
+
+
+class SpectralNorm(object):
+ # Invariant before and after each forward call:
+ # u = normalize(W @ v)
+ # NB: At initialization, this invariant is not enforced
+
+ _version = 1
+
+ # At version 1:
+ # made `W` not a buffer,
+ # added `v` as a buffer, and
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
+
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError(
+ 'Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(
+ self.dim,
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ return weight_mat.reshape(height, -1)
+
+ def compute_weight(self, module, do_power_iteration):
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
+ # updated in power iteration **in-place**. This is very important
+ # because in `DataParallel` forward, the vectors (being buffers) are
+ # broadcast from the parallelized module to each module replica,
+ # which is a new module object created on the fly. And each replica
+ # runs its own spectral norm power iteration. So simply assigning
+ # the updated vectors to the module this function runs on will cause
+ # the update to be lost forever. And the next time the parallelized
+ # module is replicated, the same randomly initialized vectors are
+ # broadcast and used!
+ #
+ # Therefore, to make the change propagate back, we rely on two
+ # important behaviors (also enforced via tests):
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
+ # is already on correct device; and it makes sure that the
+ # parallelized module is already on `device[0]`.
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
+ # just fill in the values.
+ # Therefore, since the same power iteration is performed on all
+ # devices, simply updating the tensors in-place will make sure that
+ # the module replica on `device[0]` will update the _u vector on the
+ # parallized module (by shared storage).
+ #
+ # However, after we update `u` and `v` in-place, we need to **clone**
+ # them before using them to normalize the weight. This is to support
+ # backproping through two forward passes, e.g., the common pattern in
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
+ # complain that variables needed to do backward for the first forward
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with torch.no_grad():
+ for _ in range(self.n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = normalize(torch.mv(weight_mat.t(), u),
+ dim=0,
+ eps=self.eps,
+ out=v)
+ u = normalize(torch.mv(weight_mat, v),
+ dim=0,
+ eps=self.eps,
+ out=u)
+ if self.n_power_iterations > 0:
+ # See above on why we need to clone
+ u = u.clone()
+ v = v.clone()
+
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name,
+ torch.nn.Parameter(weight.detach()))
+
+ def __call__(self, module, inputs):
+ setattr(
+ module, self.name,
+ self.compute_weight(module, do_power_iteration=module.training))
+
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
+ # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
+ # This uses pinverse in case W^T W is not invertible.
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with torch.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+
+ h, w = weight_mat.size()
+ # randomly initialize `u` and `v`
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
+
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a plain
+ # attribute.
+ setattr(module, fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
+ module._register_load_state_dict_pre_hook(
+ SpectralNormLoadStateDictPreHook(fn))
+ return fn
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormLoadStateDictPreHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ # For state_dict with version None, (assuming that it has gone through at
+ # least one training forward), we have
+ #
+ # u = normalize(W_orig @ v)
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
+ #
+ # To compute `v`, we solve `W_orig @ x = u`, and let
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
+ def __call__(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ fn = self.fn
+ version = local_metadata.get('spectral_norm',
+ {}).get(fn.name + '.version', None)
+ if version is None or version < 1:
+ with torch.no_grad():
+ weight_orig = state_dict[prefix + fn.name + '_orig']
+ # weight = state_dict.pop(prefix + fn.name)
+ # sigma = (weight_orig / weight).mean()
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
+ u = state_dict[prefix + fn.name + '_u']
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
+ # state_dict[prefix + fn.name + '_v'] = v
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormStateDictHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ def __call__(self, module, state_dict, prefix, local_metadata):
+ if 'spectral_norm' not in local_metadata:
+ local_metadata['spectral_norm'] = {}
+ key = self.fn.name + '.version'
+ if key in local_metadata['spectral_norm']:
+ raise RuntimeError(
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
+ local_metadata['spectral_norm'][key] = self.fn._version
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
+ r"""Applies spectral normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
+ power iteration method. If the dimension of the weight tensor is greater
+ than 2, it is reshaped to 2D in power iteration method to get spectral
+ norm. This is implemented via a hook that calculates spectral norm and
+ rescales weight before every :meth:`~Module.forward` call.
+
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is ``0``, except for modules that are instances of
+ ConvTranspose{1,2,3}d, when it is ``1``
+
+ Returns:
+ The original module with the spectral norm hook
+
+ Example::
+
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_u.size()
+ torch.Size([40])
+
+ """
+ if dim is None:
+ if isinstance(module,
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
+
+
+def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
+ name, module))
+
+
+def use_spectral_norm(module, use_sn=False):
+ if use_sn:
+ return spectral_norm(module)
+ return module
diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d193a48ce0eecd955d6c654ea63dc6811a09cd2
--- /dev/null
+++ b/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+ In this implementation, we allow users to choose whether use normalization
+ in the input feature and the type of vgg network. Note that the pretrained
+ path must fit the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+ the input feature must in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage, weights_only=True)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
diff --git a/basicsr/archs/vqgan_arch.py b/basicsr/archs/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c603bfa8c56b360e0d51c7c7dd61ced7bae3ae16
--- /dev/null
+++ b/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,597 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+from basicsr.archs.spectral_norm_arch import spectral_norm as _spectral_norm
+
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.beta = beta
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 /
+ self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ # min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1, 1)
+ min_encodings = torch.zeros(
+ indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ # projects last encoder layer to quantized logits
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(
+ logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w",
+ soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * \
+ torch.sum(
+ qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels,
+ kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, out_channels,
+ kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels,
+ kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convultion
+ blocks.append(nn.Conv2d(in_channels, nf,
+ kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim,
+ kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch,
+ kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels,
+ kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta # 0.25
+ self.quantize = VectorQuantizer(
+ self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ ckpt = torch.load(model_path, map_location='cpu', weights_only=True)
+ if 'params_ema' in ckpt:
+ self.load_state_dict(ckpt['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in ckpt:
+ self.load_state_dict(ckpt['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError(f'Wrong params!')
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2,
+ padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult,
+ kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult,
+ kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ ckpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(ckpt['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(ckpt['params'])
+ else:
+ raise ValueError(f'Wrong params!')
+
+ def forward(self, x):
+ return self.main(x)
+
+
+@ARCH_REGISTRY.register()
+class VQHQEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None, params='params'):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta # 0.25
+ self.quantize = VectorQuantizer(
+ self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+
+ if model_path is not None:
+ self.load_state_dict(torch.load(
+ model_path, map_location='cpu', weights_only=True)[params], strict=False)
+ logger.info(
+ f'VQGAN for latent calculation is loaded from: {model_path} [{params}]')
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ return x, codebook_loss, quant_stats
+
+
+@ARCH_REGISTRY.register()
+class Discriminator3D(nn.Module):
+ def __init__(self,
+ in_channels=3,
+ nf=32,
+ use_sigmoid=False,
+ use_spectral_norm=True,):
+ super().__init__()
+ self.use_sigmoid = use_sigmoid
+
+ self.layers = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, xs):
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.layers(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6adb4bb6a926af7a46aaec4794eee95fda02a33
--- /dev/null
+++ b/basicsr/data/__init__.py
@@ -0,0 +1,100 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must constain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2
--- /dev/null
+++ b/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+ Support enlarging the dataset for iteration-based training, for saving
+ time when restart the dataloader after each epoch
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..44a719108b98d77ee549bb5f6bc4c4d95b1c0193
--- /dev/null
+++ b/basicsr/data/data_util.py
@@ -0,0 +1,392 @@
+import cv2
+import math
+import numpy as np
+import torch
+from os import path as osp
+from PIL import Image, ImageDraw
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+ max_frame_num (int): Max number of the sequence of images (from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+ 1)image name (with extension),
+ 2)image shape,
+ 3)compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be in consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
+ f'formats. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+ Example of an meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be in consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be in consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+ raise ValueError(f'Folder {folder}folder should in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+ from scipy.ndimage import filters as filters
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+ return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
+
+
+def brush_stroke_mask(img, color=(255,255,255)):
+ min_num_vertex = 8
+ max_num_vertex = 28
+ mean_angle = 2*math.pi / 5
+ angle_range = 2*math.pi / 12
+ # training large mask ratio (training setting)
+ min_width = 30
+ max_width = 70
+ # very large mask ratio (test setting and refine after 200k)
+ # min_width = 80
+ # max_width = 120
+ def generate_mask(H, W, img=None):
+ average_radius = math.sqrt(H*H+W*W) / 8
+ mask = Image.new('RGB', (W, H), 0)
+ if img is not None: mask = img # Image.fromarray(img)
+
+ for _ in range(np.random.randint(1, 4)):
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
+ angles = []
+ vertex = []
+ for i in range(num_vertex):
+ if i % 2 == 0:
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+ else:
+ angles.append(np.random.uniform(angle_min, angle_max))
+
+ h, w = mask.size
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+ for i in range(num_vertex):
+ r = np.clip(
+ np.random.normal(loc=average_radius, scale=average_radius//2),
+ 0, 2*average_radius)
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+ vertex.append((int(new_x), int(new_y)))
+
+ draw = ImageDraw.Draw(mask)
+ width = int(np.random.uniform(min_width, max_width))
+ draw.line(vertex, fill=color, width=width)
+ for v in vertex:
+ draw.ellipse((v[0] - width//2,
+ v[1] - width//2,
+ v[0] + width//2,
+ v[1] + width//2),
+ fill=color)
+
+ return mask
+
+ width, height = img.size
+ mask = generate_mask(height, width, img)
+ return mask
+
+
+def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
+ """Generate a random free form mask with configuration.
+ Args:
+ config: Config should have configuration including IMG_SHAPES,
+ VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
+ Returns:
+ tuple: (top, left, height, width)
+ Link:
+ https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
+ """
+ height = shape[0]
+ width = shape[1]
+ mask = np.zeros((height, width), np.float32)
+ times = np.random.randint(times-5, times)
+ for i in range(times):
+ start_x = np.random.randint(width)
+ start_y = np.random.randint(height)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_len-20, max_len)
+ brush_w = 5 + np.random.randint(max_width-30, max_width)
+ end_x = (start_x + length * np.sin(angle)).astype(np.int32)
+ end_y = (start_y + length * np.cos(angle)).astype(np.int32)
+ cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
+ start_x, start_y = end_x, end_y
+ return mask.astype(np.float32)
\ No newline at end of file
diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py
new file mode 100644
index 0000000000000000000000000000000000000000..14319605d73149bb0b0cfe86294a89a102a9dac2
--- /dev/null
+++ b/basicsr/data/degradations.py
@@ -0,0 +1,764 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from scipy import special
+from scipy.stats import multivariate_normal
+from torchvision.transforms.functional import rgb_to_grayscale
+
+# -------------------------------------------------------------------- #
+# --------------------------- blur kernels --------------------------- #
+# -------------------------------------------------------------------- #
+
+
+# --------------------------- util functions --------------------------- #
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+
+ Args:
+ kernel_size (int):
+
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
+ 1))).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+ kernel (ndarrray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def cdf2(d_matrix, grid):
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
+ Used in skewed Gaussian distribution.
+
+ Args:
+ d_matrix (ndarrasy): skew matrix.
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+ cdf (ndarray): skewed cdf.
+ """
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+ grid = np.dot(grid, d_matrix)
+ cdf = rv.cdf(grid)
+ return cdf
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ isotropic (bool):
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a bivariate generalized Gaussian kernel.
+
+ ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
+
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a plateau-like anisotropic kernel.
+
+ 1 / (1+x^(beta))
+
+ Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
+
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
+
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate generalized Gaussian kernels.
+
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # assume beta_range[0] < 1 < beta_range[1]
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_plateau(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True):
+ """Randomly generate bivariate plateau kernels.
+
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi/2, math.pi/2]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # TODO: this may be not proper
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+
+ return kernel
+
+
+def random_mixed_kernels(kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=(0.6, 5),
+ sigma_y_range=(0.6, 5),
+ rotation_range=(-math.pi, math.pi),
+ betag_range=(0.5, 8),
+ betap_range=(0.5, 8),
+ noise_range=None):
+ """Randomly generate mixed kernels.
+
+ Args:
+ kernel_list (tuple): a list name of kernel types,
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
+ 'plateau_aniso']
+ kernel_prob (tuple): corresponding kernel probability for each
+ kernel type
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
+ if kernel_type == 'iso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
+ elif kernel_type == 'aniso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
+ elif kernel_type == 'generalized_iso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=True)
+ elif kernel_type == 'generalized_aniso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=False)
+ elif kernel_type == 'plateau_iso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
+ elif kernel_type == 'plateau_aniso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
+ return kernel
+
+
+np.seterr(divide='ignore', invalid='ignore')
+
+
+def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
+ """2D sinc filter
+
+ Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
+
+ Args:
+ cutoff (float): cutoff frequency in radians (pi is max)
+ kernel_size (int): horizontal and vertical size, must be odd.
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ kernel = np.fromfunction(
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
+ kernel = kernel / np.sum(kernel)
+ if pad_to > kernel_size:
+ pad_size = (pad_to - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+ return kernel
+
+
+# ------------------------------------------------------------- #
+# --------------------------- noise --------------------------- #
+# ------------------------------------------------------------- #
+
+# ----------------------- Gaussian Noise ----------------------- #
+
+
+def generate_gaussian_noise(img, sigma=10, gray_noise=False):
+ """Generate Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ sigma (float): Noise scale (measured in range 255). Default: 10.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ if gray_noise:
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
+ else:
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
+ return noise
+
+
+def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
+ """Add Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ sigma (float): Noise scale (measured in range 255). Default: 10.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+ scale (float | Tensor): Noise scale. Default: 1.0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ b, _, h, w = img.size()
+ if not isinstance(sigma, (float, int)):
+ sigma = sigma.view(img.size(0), 1, 1, 1)
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+
+ if cal_gray_noise:
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
+ noise_gray = noise_gray.view(b, 1, h, w)
+
+ # always calculate color noise
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
+
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ return noise
+
+
+def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+ scale (float | Tensor): Noise scale. Default: 1.0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ----------------------- Random Gaussian Noise ----------------------- #
+def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ return generate_gaussian_noise(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
+ sigma = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ----------------------- Poisson (Shot) Noise ----------------------- #
+
+
+def generate_poisson_noise(img, scale=1.0, gray_noise=False):
+ """Generate poisson noise.
+
+ Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+ gray_noise (bool): Whether generate gray noise. Default: False.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ if gray_noise:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # round and clip image for counting vals correctly
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = len(np.unique(img))
+ vals = 2**np.ceil(np.log2(vals))
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
+ noise = out - img
+ if gray_noise:
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
+ return noise * scale
+
+
+def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
+ """Add poisson noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+ gray_noise (bool): Whether generate gray noise. Default: False.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
+ """Generate a batch of poisson noise (PyTorch version)
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ b, _, h, w = img.size()
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+ if cal_gray_noise:
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
+ # round and clip image for counting vals correctly
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img_gray * vals) / vals
+ noise_gray = out - img_gray
+ noise_gray = noise_gray.expand(b, 3, h, w)
+
+ # always calculate color noise
+ # round and clip image for counting vals correctly
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img * vals) / vals
+ noise = out - img
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ if not isinstance(scale, (float, int)):
+ scale = scale.view(b, 1, 1, 1)
+ return noise * scale
+
+
+def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
+ """Add poisson noise to a batch of images (PyTorch version).
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ----------------------- Random Poisson (Shot) Noise ----------------------- #
+
+
+def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = np.random.uniform(scale_range[0], scale_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ return generate_poisson_noise(img, scale, gray_noise)
+
+
+def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_poisson_noise_pt(img, scale, gray_noise)
+
+
+def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+# ------------------------------------------------------------------------ #
+# --------------------------- JPEG compression --------------------------- #
+# ------------------------------------------------------------------------ #
+
+
+def add_jpg_compression(img, quality=90):
+ """Add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
+ best quality. Default: 90.
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ img = np.clip(img, 0, 1)
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
+ return img
+
+
+def random_add_jpg_compression(img, quality_range=(90, 100)):
+ """Randomly add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality_range (tuple[float] | list[float]): JPG compression quality
+ range. 0 for lowest quality, 100 for best quality.
+ Default: (90, 100).
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ quality = np.random.uniform(quality_range[0], quality_range[1])
+ return add_jpg_compression(img, quality)
diff --git a/basicsr/data/gaussian_kernels.py b/basicsr/data/gaussian_kernels.py
new file mode 100755
index 0000000000000000000000000000000000000000..0ce57f0ae52bb4efce9212dd09960ac9c7358c3a
--- /dev/null
+++ b/basicsr/data/gaussian_kernels.py
@@ -0,0 +1,690 @@
+import math
+import numpy as np
+import random
+from scipy.ndimage.interpolation import shift
+from scipy.stats import multivariate_normal
+
+
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ D = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ U = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ return np.dot(U, np.dot(D, U.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+ Args:
+ kernel_size (int):
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
+ yy.reshape(kernel_size * kernel_size,
+ 1))).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+ Returns:
+ kernel (ndarrray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def cdf2(D, grid):
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
+ Used in skewed Gaussian distribution.
+ Args:
+ D (ndarrasy): skew matrix.
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+ Returns:
+ cdf (ndarray): skewed cdf.
+ """
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+ grid = np.dot(grid, D)
+ cdf = rv.cdf(grid)
+ return cdf
+
+
+def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
+ """Generate a bivariate skew Gaussian kernel.
+ Described in `A multivariate skew normal distribution`_ by Shi et. al (2004).
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ D (ndarrasy): skew matrix.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ .. _A multivariate skew normal distribution:
+ https://www.sciencedirect.com/science/article/pii/S0047259X03001313
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ pdf = pdf2(sigma_matrix, grid)
+ cdf = cdf2(D, grid)
+ kernel = pdf * cdf
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def mass_center_shift(kernel_size, kernel):
+ """Calculate the shift of the mass center of a kenrel.
+ Args:
+ kernel_size (int):
+ kernel (ndarray): normalized kernel.
+ Returns:
+ delta_h (float):
+ delta_w (float):
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
+ delta_h = np.dot(row_sum, ax)
+ delta_w = np.dot(col_sum, ax)
+ return delta_h, delta_w
+
+
+def bivariate_skew_Gaussian_center(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ D,
+ grid=None):
+ """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ D (ndarrasy): skew matrix.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): centered and normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_anisotropic_Gaussian(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ grid=None):
+ """Generate a bivariate anisotropic Gaussian kernel.
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
+ """Generate a bivariate isotropic Gaussian kernel.
+ Args:
+ kernel_size (int):
+ sig (float):
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_generalized_Gaussian(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ beta,
+ grid=None):
+ """Generate a bivariate generalized Gaussian kernel.
+ Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
+ by Pascal et. al (2013).
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
+ https://arxiv.org/abs/1302.6498
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(
+ -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
+ """Generate a plateau-like anisotropic kernel.
+ 1 / (1+x^(beta))
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
+ """Generate a plateau-like isotropic kernel.
+ 1 / (1+x^(beta))
+ Args:
+ kernel_size (int):
+ sig (float):
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_skew_Gaussian_center(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate skew Gaussian kernels at center.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+ sigma_max = np.max([sigma_x, sigma_y])
+ thres = 3 / sigma_max
+ D = [[np.random.uniform(-thres, thres),
+ np.random.uniform(-thres, thres)],
+ [np.random.uniform(-thres, thres),
+ np.random.uniform(-thres, thres)]]
+
+ kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
+ rotation, D)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, D
+ else:
+ return kernel
+
+
+def random_bivariate_anisotropic_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate anisotropic Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
+ rotation)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation
+ else:
+ return kernel
+
+
+def random_bivariate_isotropic_Gaussian(kernel_size,
+ sigma_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate isotropic Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_range (tuple): [0.6, 5]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+
+ kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma
+ else:
+ return kernel
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate generalized Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
+ rotation, beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, beta
+ else:
+ return kernel
+
+
+def random_bivariate_plateau_type1(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate plateau type1 kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi/2, math.pi/2]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
+ beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, beta
+ else:
+ return kernel
+
+
+def random_bivariate_plateau_type1_iso(kernel_size,
+ sigma_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate plateau type1 kernels (iso).
+ Args:
+ kernel_size (int):
+ sigma_range (tuple): [0.6, 5]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+ beta = np.random.uniform(beta_range[0], beta_range[1])
+
+ kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma, beta
+ else:
+ return kernel
+
+
+def random_mixed_kernels(kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=[0.6, 5],
+ sigma_y_range=[0.6, 5],
+ rotation_range=[-math.pi, math.pi],
+ beta_range=[0.5, 8],
+ noise_range=None):
+ """Randomly generate mixed kernels.
+ Args:
+ kernel_list (tuple): a list name of kenrel types,
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
+ kernel_prob (tuple): corresponding kernel probability for each kernel type
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
+ if kernel_type == 'iso':
+ kernel = random_bivariate_isotropic_Gaussian(
+ kernel_size, sigma_x_range, noise_range=noise_range)
+ elif kernel_type == 'aniso':
+ kernel = random_bivariate_anisotropic_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=noise_range)
+ elif kernel_type == 'skew':
+ kernel = random_bivariate_skew_Gaussian_center(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=noise_range)
+ elif kernel_type == 'generalized':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=noise_range)
+ elif kernel_type == 'plateau_iso':
+ kernel = random_bivariate_plateau_type1_iso(
+ kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
+ elif kernel_type == 'plateau_aniso':
+ kernel = random_bivariate_plateau_type1(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=noise_range)
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def show_one_kernel():
+ import matplotlib.pyplot as plt
+ kernel_size = 21
+
+ # bivariate skew Gaussian
+ D = [[0, 0], [0, 0]]
+ D = [[3 / 4, 0], [0, 0.5]]
+ kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
+ # bivariate anisotropic Gaussian
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
+ # bivariate anisotropic Gaussian
+ kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
+ # bivariate generalized Gaussian
+ kernel = bivariate_generalized_Gaussian(
+ kernel_size, 2, 4, -math.pi / 4, beta=4)
+
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ print(delta_h, delta_w)
+
+ fig, axs = plt.subplots(nrows=2, ncols=2)
+ # axs.set_axis_off()
+ ax = axs[0][0]
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
+ fig.colorbar(im, ax=ax)
+
+ # image
+ ax = axs[0][1]
+ kernel_vis = kernel - np.min(kernel)
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+ ax.imshow(kernel_vis, interpolation='nearest')
+
+ _, xx, yy = mesh_grid(kernel_size)
+ # contour
+ ax = axs[1][0]
+ CS = ax.contour(xx, yy, kernel, origin='upper')
+ ax.clabel(CS, inline=1, fontsize=3)
+
+ # contourf
+ ax = axs[1][1]
+ kernel = kernel / np.max(kernel)
+ p = ax.contourf(
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+ fig.colorbar(p)
+
+ plt.show()
+
+
+def show_plateau_kernel():
+ import matplotlib.pyplot as plt
+ kernel_size = 21
+
+ kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+ kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
+ kernel_gau = bivariate_generalized_Gaussian(
+ kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ print(delta_h, delta_w)
+
+ # kernel_slice = kernel[10, :]
+ # kernel_gau_slice = kernel_gau[10, :]
+ # kernel_norm_slice = kernel_norm[10, :]
+ # fig, ax = plt.subplots()
+ # t = list(range(1, 22))
+
+ # ax.plot(t, kernel_gau_slice)
+ # ax.plot(t, kernel_slice)
+ # ax.plot(t, kernel_norm_slice)
+
+ # t = np.arange(0, 10, 0.1)
+ # y = np.exp(-0.5 * t)
+ # y2 = np.reciprocal(1 + t)
+ # print(t.shape)
+ # print(y.shape)
+ # ax.plot(t, y)
+ # ax.plot(t, y2)
+ # plt.show()
+
+ fig, axs = plt.subplots(nrows=2, ncols=2)
+ # axs.set_axis_off()
+ ax = axs[0][0]
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
+ fig.colorbar(im, ax=ax)
+
+ # image
+ ax = axs[0][1]
+ kernel_vis = kernel - np.min(kernel)
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+ ax.imshow(kernel_vis, interpolation='nearest')
+
+ _, xx, yy = mesh_grid(kernel_size)
+ # contour
+ ax = axs[1][0]
+ CS = ax.contour(xx, yy, kernel, origin='upper')
+ ax.clabel(CS, inline=1, fontsize=3)
+
+ # contourf
+ ax = axs[1][1]
+ kernel = kernel / np.max(kernel)
+ p = ax.contourf(
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+ fig.colorbar(p)
+
+ plt.show()
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a6c07b1af76412bf862d57453a768432634f09
--- /dev/null
+++ b/basicsr/data/paired_image_dataset.py
@@ -0,0 +1,101 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
+ GT image pairs.
+
+ There are three modes:
+ 1. 'lmdb': Use lmdb files.
+ If opt['io_backend'] == lmdb.
+ 2. 'meta_info_file': Use meta information file to generate paths.
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+ 3. 'folder': Scan folders to generate paths.
+ The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Default: '{}'.
+ gt_size (int): Cropped patched size for gt patches.
+ use_flip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (bool): Scale, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(PairedImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ if 'filename_tmpl' in opt:
+ self.filename_tmpl = opt['filename_tmpl']
+ else:
+ self.filename_tmpl = '{}'
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+ self.opt['meta_info_file'], self.filename_tmpl)
+ else:
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
+
+ # TODO: color space transform
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0
--- /dev/null
+++ b/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+ It may consums more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabef1d7e80866888f3b57ecfeb4d97c93bcb5cd
--- /dev/null
+++ b/basicsr/data/reds_dataset.py
@@ -0,0 +1,352 @@
+import numpy as np
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.flow_util import dequantize_flow
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class REDSDataset(data.Dataset):
+ """REDS dataset for training.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or 'official'.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patched size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ scale (bool): Scale, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
+ assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
+ self.num_frame = opt['num_frame']
+ self.num_half_frames = opt['num_frame'] // 2
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
+ f"Supported ones are ['official', 'REDS4'].")
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ interval_str = ','.join(str(x) for x in opt['interval_list'])
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+ center_frame_idx = int(frame_name)
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = center_frame_idx - self.num_half_frames * interval
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ # each clip has 100 frames starting from 0 to 99
+ while (start_frame_idx < 0) or (end_frame_idx > 99):
+ center_frame_idx = random.randint(0, 99)
+ start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ frame_name = f'{center_frame_idx:08d}'
+ neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+ # get the GT frame (as the center frame)
+ if self.is_lmdb:
+ img_gt_path = f'{clip_name}/{frame_name}'
+ else:
+ img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # get the neighboring LQ frames
+ img_lqs = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get flows
+ if self.flow_root is not None:
+ img_flows = []
+ # read previous flows
+ for i in range(self.num_half_frames, 0, -1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_p{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+ # read next flows
+ for i in range(1, self.num_half_frames + 1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_n{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+
+ # for random crop, here, img_flows and img_lqs have the same
+ # spatial size
+ img_lqs.extend(img_flows)
+
+ # randomly crop
+ img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+ if self.flow_root is not None:
+ img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
+
+ # augmentation - flip, rotate
+ img_lqs.append(img_gt)
+ if self.flow_root is not None:
+ img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
+ else:
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[0:-1], dim=0)
+ img_gt = img_results[-1]
+
+ if self.flow_root is not None:
+ img_flows = img2tensor(img_flows)
+ # add the zero center flow
+ img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
+ img_flows = torch.stack(img_flows, dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_flows: (t, 2, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ if self.flow_root is not None:
+ return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
+ else:
+ return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class REDSRecurrentDataset(data.Dataset):
+ """REDS dataset for training recurrent networks.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or 'official'.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patched size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ scale (bool): Scale, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSRecurrentDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.num_frame = opt['num_frame']
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
+ f"Supported ones are ['official', 'REDS4'].")
+ if opt['test_mode']:
+ self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
+ else:
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if hasattr(self, 'flow_root') and self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt.get('interval_list', [1])
+ self.random_reverse = opt.get('random_reverse', False)
+ interval_str = ','.join(str(x) for x in self.interval_list)
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = int(frame_name)
+ if start_frame_idx > 100 - self.num_frame * interval:
+ start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
+ end_frame_idx = start_frame_idx + self.num_frame * interval
+
+ neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ # get the neighboring LQ and GT frames
+ img_lqs = []
+ img_gts = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ img_gt_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
+
+ # get LQ
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get GT
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ img_gts.append(img_gt)
+
+ # randomly crop
+ img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.extend(img_gts)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
+ img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead9dc73ed063e1c5865040eaa2652b26aa3ad3
--- /dev/null
+++ b/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+ """Paired random crop.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ h_lq, w_lq, _ = img_lqs[0].shape
+ h_gt, w_gt, _ = img_gts[0].shape
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+ rotation (bool): Ratotation. Default: True.
+ flows (list[ndarray]: Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
diff --git a/basicsr/data/vfhq_real_degradation2_dataset.py b/basicsr/data/vfhq_real_degradation2_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..4c4a9a4ee605e5f8d5c068b67d6e6ee14981dd00
--- /dev/null
+++ b/basicsr/data/vfhq_real_degradation2_dataset.py
@@ -0,0 +1,411 @@
+import os
+import random
+from pathlib import Path
+
+from PIL import Image
+import cv2
+import ffmpeg
+import io
+import av
+import numpy as np
+import torch
+from torchvision.transforms.functional import normalize
+from basicsr.data.degradations import (random_add_gaussian_noise,
+ random_mixed_kernels)
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+from facelib.utils.face_restoration_helper import FaceAligner
+from torch.utils import data as data
+
+
+@DATASET_REGISTRY.register()
+class SingleVFHQDataset(data.Dataset):
+ """Support for blind setting adopted in paper. We excludes the random scale compared to GFPGAN.
+
+ This dataset is adopted in BasicVSR.
+
+ The degradation order is blur+downsample+noise
+
+ Note that we skip the low quality frames within the VFHQ clip.
+ Directly read image by cv2. Generate LR images online.
+ NOTE: The specific degradation order is blur-noise-downsample-crf-upsample
+
+ The keys are generated from a meta info txt file.
+
+ Key format: subfolder-name/clip-length/frame-name
+ Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_clip_meta_info (srt): Data root path for meta info of each gt clip.
+ global_meta_info_file (str): Path for global meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_flip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+ """
+
+ def __init__(self, opt):
+ super(SingleVFHQDataset, self).__init__()
+ self.opt = opt
+ self.gt_root = Path(opt['dataroot_gt'])
+ self.normalize = opt.get('normalize', False)
+ self.need_align = opt.get('need_align', False)
+ logger = get_root_logger()
+
+ self.keys = []
+ with open(opt['global_meta_info_file'], 'r') as fin:
+ for line in fin:
+ real_clip_path = '/'.join(line.split('/')[:-1])
+ clip_length = line.split('/')[-1]
+ clip_length = int(clip_length)
+ self.keys.extend(
+ [f'{real_clip_path}/{clip_length:08d}/{frame_idx:08d}' for frame_idx in range(int(clip_length))])
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ self.io_backend_opt['db_paths'] = [self.gt_root]
+ self.io_backend_opt['client_keys'] = ['gt']
+
+ if self.need_align:
+ self.dataroot_meta_info = opt['dataroot_meta_info']
+ self.face_aligner = FaceAligner(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ key = self.keys[index]
+ real_clip_path = '/'.join(key.split('/')[:-2])
+ clip_length = int(key.split('/')[-2])
+ frame_idx = int(key.split('/')[-1])
+
+ # get the neighboring GT frames
+ flag = real_clip_path.split('/')[0]
+ clip_name = real_clip_path.split('/')[-1]
+
+ paths = sorted(list(scandir(os.path.join(
+ self.gt_root, clip_name))))
+
+ assert len(paths) == clip_length, "Wrong length of frame list"
+
+ img_gt_path = os.path.join(
+ self.gt_root, clip_name, paths[frame_idx])
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # alignment
+ if self.need_align:
+ clip_info_path = os.path.join(
+ self.dataroot_meta_info, f'{clip_name}.txt')
+ clip_info = []
+ with open(clip_info_path, 'r', encoding='utf-8') as fin:
+ for line in fin:
+ line = line.strip()
+ if line.startswith('0'):
+ clip_info.append(line)
+
+ landmarks_str = clip_info[frame_idx].split(' ')[1:]
+ landmarks = np.array([float(x)
+ for x in landmarks_str]).reshape(5, 2)
+ self.face_aligner.clean_all()
+ # align and warp each face
+ img_gt = self.face_aligner.align_single_face(img_gt, landmarks)
+
+ # augmentation - flip, rotate
+ img_gt = augment(img_gt, self.opt['use_flip'], self.opt['use_rot'])
+ img_in = img_gt
+
+ # ------------- end --------------#
+ img_in, img_gt = img2tensor([img_in, img_gt])
+ if self.normalize:
+ normalize(img_in, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+ normalize(img_gt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'in': img_in, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
+@DATASET_REGISTRY.register()
+class VFHQDataset(data.Dataset):
+ """Support for blind setting adopted in paper. We excludes the random scale compared to GFPGAN.
+
+ This dataset is adopted in BasicVSR.
+
+ The degradation order is blur+downsample+noise
+
+ Note that we skip the low quality frames within the VFHQ clip.
+ Directly read image by cv2. Generate LR images online.
+ NOTE: The specific degradation order is blur-noise-downsample-crf-upsample
+
+ The keys are generated from a meta info txt file.
+
+ Key format: subfolder-name/clip-length/frame-name
+ Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_clip_meta_info (srt): Data root path for meta info of each gt clip.
+ global_meta_info_file (str): Path for global meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_flip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+ """
+
+ def __init__(self, opt):
+ super(VFHQDataset, self).__init__()
+ self.opt = opt
+ self.gt_root = Path(opt['dataroot_gt'])
+
+ self.num_frame = opt['num_frame']
+ self.scale = opt['scale']
+ self.need_align = opt.get('need_align', False)
+ self.normalize = opt.get('normalize', False)
+
+ self.keys = []
+ with open(opt['global_meta_info_file'], 'r') as fin:
+ for line in fin:
+ real_clip_path = '/'.join(line.split('/')[:-1])
+ clip_length = line.split('/')[-1]
+ clip_length = int(clip_length)
+ self.keys.extend(
+ [f'{real_clip_path}/{clip_length:08d}/{frame_idx:08d}' for frame_idx in range(int(clip_length))])
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ self.io_backend_opt['db_paths'] = [self.gt_root]
+ self.io_backend_opt['client_keys'] = ['gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ interval_str = ','.join(str(x) for x in opt['interval_list'])
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ # degradations
+ # blur
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob']
+ self.blur_x_sigma = opt['blur_x_sigma']
+ self.blur_y_sigma = opt['blur_y_sigma']
+ # noise
+ self.noise_range = opt['noise_range']
+ # resize
+ self.resize_prob = opt['resize_prob']
+ # crf
+ self.crf_range = opt['crf_range']
+ # codec
+ self.vcodec = opt['vcodec']
+ self.vcodec_prob = opt['vcodec_prob']
+
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
+ f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
+ f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+ logger.info(
+ f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
+ logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')
+
+ if self.need_align:
+ self.dataroot_meta_info = opt['dataroot_meta_info']
+ self.face_aligner = FaceAligner(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(
+ self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ key = self.keys[index]
+ real_clip_path = '/'.join(key.split('/')[:-2])
+ clip_length = int(key.split('/')[-2])
+ frame_idx = int(key.split('/')[-1])
+ clip_name = real_clip_path.split('/')[-1]
+
+ paths = sorted(list(scandir(os.path.join(
+ self.gt_root, clip_name))))
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # exceed the length, re-select a new clip
+ while (clip_length - self.num_frame * interval) < 0:
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ # print(self.num_frame, type(self.num_frame))
+ # print(interval, type(interval))
+ start_frame_idx = frame_idx - self.num_frame // 2 * interval
+ end_frame_idx = frame_idx + self.num_frame // 2 * interval
+
+ # flag = (start_frame_idx < 0) or (end_frame_idx > clip_length)
+ # print(key, start_frame_idx, end_frame_idx, interval, flag)
+ # each clip has 100+ frames
+ while (start_frame_idx < 0) or (end_frame_idx > clip_length):
+ frame_idx = random.randint(self.num_frame//2 * interval,
+ clip_length - self.num_frame//2 * interval)
+ start_frame_idx = frame_idx - self.num_frame // 2 * interval
+ end_frame_idx = frame_idx + self.num_frame // 2 * interval
+ neighbor_list = list(
+ range(start_frame_idx, end_frame_idx, interval))
+ # print(start_frame_idx, end_frame_idx, frame_idx, interval)
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ assert len(neighbor_list) == self.num_frame, (
+ f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+ # get the neighboring GT frames
+ img_gts = []
+
+ if self.need_align:
+ clip_info_path = os.path.join(
+ self.dataroot_meta_info, f'{clip_name}.txt')
+ clip_info = []
+ with open(clip_info_path, 'r', encoding='utf-8') as fin:
+ for line in fin:
+ line = line.strip()
+ if line.startswith('0'):
+ clip_info.append(line)
+
+ for neighbor in neighbor_list:
+ assert paths[neighbor] == clip_info[neighbor].split(' ')[0], \
+ f'{clip_name}: Mismatch frame {paths[neighbor]} and {clip_info[neighbor]}'
+ # img_gt_path = os.path.join(
+ # self.gt_root, clip_name, f'{neighbor:08d}.png')
+ img_gt_path = os.path.join(
+ self.gt_root, clip_name, paths[neighbor])
+ # img_bytes = self.file_client.get(img_gt_path, 'gt')
+ # img_gt = imfrombytes(img_bytes, float32=True)
+ # img_gt = cv2.imread(img_gt_path) / 255.0
+ img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
+ img_gts.append(img_gt)
+
+ # augmentation - flip, rotate
+ img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot'])
+
+ # ------------- generate LQ frames --------------#
+ # add blur
+ kernel = random_mixed_kernels(self.kernel_list, self.kernel_prob, self.blur_kernel_size, self.blur_x_sigma,
+ self.blur_y_sigma)
+ img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts]
+ # add noise
+ img_lqs = [
+ random_add_gaussian_noise(v, self.noise_range, gray_prob=0.5, clip=True, rounds=False) for v in img_lqs
+ ]
+ # downsample
+ original_height, original_width = img_gts[0].shape[0:2]
+ resize_type = random.choices(
+ [cv2.INTER_AREA, cv2.INTER_LINEAR, cv2.INTER_CUBIC], self.resize_prob)[0]
+ resized_height, resized_width = int(
+ original_height // self.scale), int(original_width // self.scale)
+ # ensure the resized_height and resized_width are even numbers
+ img_lqs = [cv2.resize(v, (resized_width, resized_height),
+ interpolation=resize_type) for v in img_lqs]
+ # add noise
+ img_lqs = [
+ random_add_gaussian_noise(v, self.noise_range, gray_prob=0.5, clip=True, rounds=False) for v in img_lqs
+ ]
+
+ # ffmpeg
+ crf = np.random.randint(self.crf_range[0], self.crf_range[1])
+ codec = random.choices(self.vcodec, self.vcodec_prob)[0]
+
+ buf = io.BytesIO()
+ with av.open(buf, 'w', 'mp4') as container:
+ stream = container.add_stream(codec, rate=1)
+ stream.height = resized_height
+ stream.width = resized_width
+ stream.pix_fmt = 'yuv420p'
+ stream.options = {'crf': str(crf)}
+
+ for img_lq in img_lqs:
+ img_lq = np.clip(img_lq * 255, 0, 255).astype(np.uint8)
+ frame = av.VideoFrame.from_ndarray(img_lq, format='rgb24')
+ frame.pict_type = 0 # Changed from 'NONE' to 0
+ for packet in stream.encode(frame):
+ container.mux(packet)
+
+ # Flush stream
+ for packet in stream.encode():
+ container.mux(packet)
+
+ img_lqs = []
+ with av.open(buf, 'r', 'mp4') as container:
+ if container.streams.video:
+ for frame in container.decode(**{'video': 0}):
+ img_lqs.append(frame.to_rgb().to_ndarray() / 255.)
+
+ assert len(img_lqs) == len(img_gts), 'Wrong length'
+ # ------------ Align -------------#
+ if self.need_align:
+ align_lqs, align_gts = [], []
+ for frame_idx, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
+ landmarks_str = clip_info[frame_idx].split(' ')[1:]
+ # print(clip_name, paths[neighbor], landmarks_str)
+ landmarks = np.array([float(x)
+ for x in landmarks_str]).reshape(5, 2)
+ self.face_aligner.clean_all()
+ # align and warp each face
+ img_lq, img_gt = self.face_aligner.align_pair_face(
+ img_lq, img_gt, landmarks)
+ align_lqs.append(img_lq)
+ align_gts.append(img_gt)
+ img_lqs, img_gts = align_lqs, align_gts
+
+ # ------------- end --------------#
+ img_gts = img2tensor(img_gts)
+ img_lqs = img2tensor(img_lqs)
+ img_gts = torch.stack(img_gts, dim=0)
+ img_lqs = torch.stack(img_lqs, dim=0)
+
+ if self.normalize:
+ normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+ normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
\ No newline at end of file
diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f5688f724db356b8033d57f5d666549f9945179
--- /dev/null
+++ b/basicsr/data/video_test_dataset.py
@@ -0,0 +1,232 @@
+import glob
+import torch
+import os
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+from collections import defaultdict
+import numpy as np
+import cv2
+
+from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+from basicsr.utils.img_util import img2tensor, tensor2img
+from facelib.utils.face_restoration_helper import FaceAligner
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDataset(data.Dataset):
+ """Video test dataset.
+
+ Supported datasets: Vid4, REDS4, REDSofficial.
+ More generally, it supports testing dataset with following structures:
+
+ ::
+
+ dataroot
+ ├── subfolder1
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── subfolder2
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── ...
+
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwarg.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ global_meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
+ in the dataroot will be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ self.interval = opt['interval']
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [],
+ 'folder': [], 'idx': [], 'border': []}
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ self.imgs_lq, self.imgs_gt = {}, {}
+ if 'global_meta_info_file' in opt:
+ with open(opt['global_meta_info_file'], 'r') as fin:
+ subfolders = [line.split('/')[0] for line in fin]
+ subfolders_lq = [osp.join(self.lq_root, key)
+ for key in subfolders]
+ subfolders_gt = [osp.join(self.gt_root, key)
+ for key in subfolders]
+ else:
+ subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+ subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+ for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
+ # get frame list for lq and gt
+ subfolder_name = osp.basename(subfolder_lq)
+ img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))[
+ ::self.interval]
+ img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))[
+ ::self.interval]
+
+ max_idx = len(img_paths_lq)
+ assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
+ f' and gt folders ({len(img_paths_gt)})')
+
+ self.data_info['lq_path'].extend(img_paths_lq)
+ self.data_info['gt_path'].extend(img_paths_gt)
+ self.data_info['folder'].extend([subfolder_name] * max_idx)
+ for i in range(max_idx):
+ self.data_info['idx'].append(f'{i}/{max_idx}')
+ border_l = [0] * max_idx
+ for i in range(self.opt['num_frame'] // 2):
+ border_l[i] = 1
+ border_l[max_idx - i - 1] = 1
+ self.data_info['border'].extend(border_l)
+
+ # cache data or save the frame list
+ if self.cache_data:
+ logger.info(
+ f'Cache {subfolder_name} for VideoTestDataset...')
+ self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+ self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+ else:
+ self.imgs_lq[subfolder_name] = img_paths_lq
+ self.imgs_gt[subfolder_name] = img_paths_gt
+
+ self.normalize = opt.get('normalize', False)
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(
+ idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder].index_select(
+ 0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+ img_gt.squeeze_(0)
+
+ if self.normalize:
+ normalize(imgs_lq, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+ normalize(img_gt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoRecurrentTestDataset(VideoTestDataset):
+ """Video test dataset for recurrent architectures, which takes LR video
+ frames as input and output corresponding HR video frames.
+
+ Args:
+ opt (dict): Same as VideoTestDataset. Unused opt:
+ padding (str): Padding mode.
+
+ """
+
+ def __init__(self, opt):
+ super(VideoRecurrentTestDataset, self).__init__(opt)
+ # Find unique folder strings
+ self.folders = sorted(list(set(self.data_info['folder'])))
+ self.need_align = opt.get('need_align', False)
+ self.normalize = opt.get('normalize', False)
+
+ if self.need_align:
+ self.dataroot_meta_info = opt['dataroot_meta_info']
+ self.face_aligner = FaceAligner(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,)
+
+ def __getitem__(self, index):
+ folder = self.folders[index]
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder]
+ imgs_gt = self.imgs_gt[folder]
+ else:
+ imgs_lq = read_img_seq(self.imgs_lq[folder])
+ imgs_gt = read_img_seq(self.imgs_gt[folder])
+
+ if self.need_align:
+ clip_info_path = os.path.join(
+ self.dataroot_meta_info, f'{folder}.txt')
+ clip_info = []
+ with open(clip_info_path, 'r', encoding='utf-8') as fin:
+ for line in fin:
+ line = line.strip()
+ if line.startswith('0'):
+ clip_info.append(line)
+
+ align_lqs, align_gts = [], []
+ for frame_idx, (img_lq, img_gt) in enumerate(zip(imgs_lq, imgs_gt)):
+ img_lq = tensor2img(img_lq) / 255.0
+ img_gt = tensor2img(img_gt) / 255.0
+ landmarks_str = clip_info[frame_idx].split(' ')[1:]
+ # print(clip_name, paths[neighbor], landmarks_str)
+ landmarks = np.array([float(x)
+ for x in landmarks_str]).reshape(5, 2)
+ self.face_aligner.clean_all()
+ # align and warp each face
+ img_lq, img_gt = self.face_aligner.align_pair_face(
+ img_lq, img_gt, landmarks)
+ align_lqs.append(img_lq)
+ align_gts.append(img_gt)
+ img_lqs, img_gts = align_lqs, align_gts
+
+ img_gts = img2tensor(img_gts)
+ img_lqs = img2tensor(img_lqs)
+ imgs_gt = torch.stack(img_gts, dim=0)
+ imgs_lq = torch.stack(img_lqs, dim=0)
+
+ if self.normalize:
+ normalize(imgs_lq, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+ normalize(imgs_gt, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
+
+ return {
+ 'lq': imgs_lq,
+ 'gt': imgs_gt,
+ 'folder': folder,
+ }
+
+ def __len__(self):
+ return len(self.folders)
diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b184e74c861e6fca0c548692a9a949a6100b0aa
--- /dev/null
+++ b/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+ gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+ 'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+ opt (dict): Configuration. It must constain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss
diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698
--- /dev/null
+++ b/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
diff --git a/basicsr/losses/losses.py b/basicsr/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bcf272cfb756d99451a3005567ea4d4c9059067
--- /dev/null
+++ b/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(MSELoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero.
+ Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+ def forward(self, pred, weight=None):
+ y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
+ x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
+
+ loss = x_diff + y_diff
+
+ return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+ 1.0 in calculting losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+ loss will be calculated and the loss will multiplied by the
+ weight. Default: 1.0.
+ style_weight (float): If `style_weight > 0`, the style loss will be
+ calculated and the loss will multiplied by the weight.
+ Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+ self.criterion = torch.nn.L2loss()
+ elif self.criterion_type == 'mse':
+ self.criterion = torch.nn.MSELoss(reduction='mean')
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+ def __init__(self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=False,):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+ Note that loss_weight is only for generators; and it is always 1.0
+ for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+ """wgan loss with soft plus. softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+ target_is_real (bool): Whether the targe is real or fake.
+ is_disc (bool): Whether the loss for discriminators or not.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d55cc8321f124c918d78465b053aef67f13a33
--- /dev/null
+++ b/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+ opt (dict): Configuration. It must constain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d18f0f7816431bed6af9d58319c6435bdf5c971
--- /dev/null
+++ b/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+ (ndarray): Images with range [0, 255] (float type) without round.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd950699c2495880236883861d9e199f900eae8
--- /dev/null
+++ b/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: psnr result.
+ """
+
+ assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: ssim result.
+ """
+
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SSIM calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: ssim result.
+ """
+
+ assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ ssims = []
+ for i in range(img1.shape[2]):
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bde45f003698a5b15d3517ae47b59ef1d86e0c
--- /dev/null
+++ b/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+ opt (dict): Configuration. It must constain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a6132f5d8806811cbd9cb0cb9db7a46c7e70966
--- /dev/null
+++ b/basicsr/models/base_model.py
@@ -0,0 +1,379 @@
+import logging
+import os
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from basicsr.models import lr_scheduler as lr_scheduler
+from basicsr.utils import get_root_logger
+from basicsr.utils.dist_util import master_only
+
+logger = logging.getLogger('basicsr')
+
+
+class BaseModel():
+ """Base model."""
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.is_train = opt['is_train']
+ self.schedulers = []
+ self.optimizers = []
+
+ def feed_data(self, data):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ pass
+
+ def save(self, epoch, current_iter):
+ """Save networks and training state."""
+ pass
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """Validation function.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _initialize_best_metric_results(self, dataset_name):
+ """Initialize the best metric results dict for recording the best metric value and iteration."""
+ if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
+ return
+ elif not hasattr(self, 'best_metric_results'):
+ self.best_metric_results = dict()
+
+ # add a dataset record
+ record = dict()
+ for metric, content in self.opt['val']['metrics'].items():
+ better = content.get('better', 'higher')
+ init_val = float('-inf') if better == 'higher' else float('inf')
+ record[metric] = dict(better=better, val=init_val, iter=-1)
+ self.best_metric_results[dataset_name] = record
+
+ def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
+ if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
+ if val >= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+ else:
+ if val <= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+
+ def model_ema(self, decay=0.999):
+ net_g = self.get_bare_model(self.net_g)
+
+ net_g_params = dict(net_g.named_parameters())
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+ for k in net_g_ema_params.keys():
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def model_to_device(self, net):
+ """Model to device. It also warps models with DistributedDataParallel
+ or DataParallel.
+
+ Args:
+ net (nn.Module)
+ """
+ net = net.to(self.device)
+ if self.opt['dist']:
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
+ net = DistributedDataParallel(
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+ elif self.opt['num_gpu'] > 1:
+ net = DataParallel(net)
+ return net
+
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
+ if optim_type == 'Adam':
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
+ elif optim_type == 'AdamW':
+ optimizer = torch.optim.AdamW(params, lr, **kwargs)
+ elif optim_type == 'Adamax':
+ optimizer = torch.optim.Adamax(params, lr, **kwargs)
+ elif optim_type == 'SGD':
+ optimizer = torch.optim.SGD(params, lr, **kwargs)
+ elif optim_type == 'ASGD':
+ optimizer = torch.optim.ASGD(params, lr, **kwargs)
+ elif optim_type == 'RMSprop':
+ optimizer = torch.optim.RMSprop(params, lr, **kwargs)
+ elif optim_type == 'Rprop':
+ optimizer = torch.optim.Rprop(params, lr, **kwargs)
+ else:
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+ return optimizer
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ train_opt = self.opt['train']
+ scheduler_type = train_opt['scheduler'].pop('type')
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+ else:
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def get_bare_model(self, net):
+ """Get bare model, especially under wrapping with
+ DistributedDataParallel or DataParallel.
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net = net.module
+ return net
+
+ @master_only
+ def print_network(self, net):
+ """Print the str and parameter number of a network.
+
+ Args:
+ net (nn.Module)
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
+ else:
+ net_cls_str = f'{net.__class__.__name__}'
+
+ net = self.get_bare_model(net)
+ net_str = str(net)
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+ logger = get_root_logger()
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+ logger.info(net_str)
+
+ def _set_lr(self, lr_groups_l):
+ """Set learning rate for warm-up.
+
+ Args:
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
+ """
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
+ param_group['lr'] = lr
+
+ def _get_init_lr(self):
+ """Get the initial lr, which is set by the scheduler.
+ """
+ init_lr_groups_l = []
+ for optimizer in self.optimizers:
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+ return init_lr_groups_l
+
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
+ Default: -1.
+ """
+ if current_iter > 1:
+ for scheduler in self.schedulers:
+ scheduler.step()
+ # set up warm-up learning rate
+ if current_iter < warmup_iter:
+ # get initial lr for each group
+ init_lr_g_l = self._get_init_lr()
+ # modify warming-up learning rates
+ # currently only support linearly warm up
+ warm_up_lr_l = []
+ for init_lr_g in init_lr_g_l:
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+ # set learning rate
+ self._set_lr(warm_up_lr_l)
+
+ def get_current_learning_rate(self):
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+ @master_only
+ def save_network(self, net, net_label, current_iter, param_key='params'):
+ """Save networks.
+
+ Args:
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ param_key (str | list[str]): The parameter key(s) to save network.
+ Default: 'params'.
+ """
+ if current_iter == -1:
+ current_iter = 'latest'
+ save_filename = f'{net_label}_{current_iter}.pth'
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+ net = net if isinstance(net, list) else [net]
+ param_key = param_key if isinstance(param_key, list) else [param_key]
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+ save_dict = {}
+ for net_, param_key_ in zip(net, param_key):
+ net_ = self.get_bare_model(net_)
+ state_dict = net_.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ state_dict[key] = param.cpu()
+ save_dict[param_key_] = state_dict
+
+ torch.save(save_dict, save_path)
+
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+ """Print keys with differnet name or different size when loading models.
+
+ 1. Print keys with differnet names.
+ 2. If strict=False, print the same key but with different tensor size.
+ It also ignore these keys with different sizes (not load).
+
+ Args:
+ crt_net (torch model): Current network.
+ load_net (dict): Loaded network.
+ strict (bool): Whether strictly loaded. Default: True.
+ """
+ crt_net = self.get_bare_model(crt_net)
+ crt_net = crt_net.state_dict()
+ crt_net_keys = set(crt_net.keys())
+ load_net_keys = set(load_net.keys())
+
+ logger = get_root_logger()
+ if crt_net_keys != load_net_keys:
+ logger.warning('Current net - loaded net:')
+ for v in sorted(list(crt_net_keys - load_net_keys)):
+ logger.warning(f' {v}')
+ logger.warning('Loaded net - current net:')
+ for v in sorted(list(load_net_keys - crt_net_keys)):
+ logger.warning(f' {v}')
+
+ # check the size for the same keys
+ if not strict:
+ common_keys = crt_net_keys & load_net_keys
+ for k in common_keys:
+ if crt_net[k].size() != load_net[k].size():
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+ load_net[k + '.ignore'] = load_net.pop(k)
+
+ def load_network(self, net, load_path, strict=True, param_key='params'):
+ """Load network.
+
+ Args:
+ load_path (str): The path of networks to be loaded.
+ net (nn.Module): Network.
+ strict (bool): Whether strictly loaded.
+ param_key (str): The parameter key of loaded network. If set to
+ None, use the root 'path'.
+ Default: 'params'.
+ """
+ logger = get_root_logger()
+ net = self.get_bare_model(net)
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage, weights_only=True)
+ if param_key is not None:
+ if param_key not in load_net and 'params' in load_net:
+ param_key = 'params'
+ logger.info('Loading: params_ema does not exist, use params.')
+ load_net = load_net[param_key]
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ self._print_different_keys_loading(net, load_net, strict)
+ net.load_state_dict(load_net, strict=strict)
+
+ @master_only
+ def save_training_state(self, epoch, current_iter):
+ """Save training states during training, which will be used for
+ resuming.
+
+ Args:
+ epoch (int): Current epoch.
+ current_iter (int): Current iteration.
+ """
+ if current_iter != -1:
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
+ for o in self.optimizers:
+ state['optimizers'].append(o.state_dict())
+ for s in self.schedulers:
+ state['schedulers'].append(s.state_dict())
+ save_filename = f'{current_iter}.state'
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
+
+ # avoid occasional writing errors
+ retry = 3
+ while retry > 0:
+ try:
+ torch.save(state, save_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
+ time.sleep(1)
+ else:
+ break
+ finally:
+ retry -= 1
+ if retry == 0:
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+ # raise IOError(f'Cannot save {save_path}.')
+
+ def resume_training(self, resume_state):
+ """Reload the optimizers and schedulers for resumed training.
+
+ Args:
+ resume_state (dict): Resume state.
+ """
+ resume_optimizers = resume_state['optimizers']
+ resume_schedulers = resume_state['schedulers']
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+ for i, o in enumerate(resume_optimizers):
+ self.optimizers[i].load_state_dict(o)
+ for i, s in enumerate(resume_schedulers):
+ self.schedulers[i].load_state_dict(s)
+
+ def reduce_loss_dict(self, loss_dict):
+ """reduce loss dict.
+
+ In distributed training, it averages the losses among different GPUs .
+
+ Args:
+ loss_dict (OrderedDict): Loss dict.
+ """
+ with torch.no_grad():
+ if self.opt['dist']:
+ keys = []
+ losses = []
+ for name, value in loss_dict.items():
+ keys.append(name)
+ losses.append(value)
+ losses = torch.stack(losses, 0)
+ torch.distributed.reduce(losses, dst=0)
+ if self.opt['rank'] == 0:
+ losses /= self.opt['world_size']
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+ log_dict = OrderedDict()
+ for name, value in loss_dict.items():
+ log_dict[name] = value.mean().item()
+
+ return log_dict
diff --git a/basicsr/models/keep_gan_model.py b/basicsr/models/keep_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..528247dfc07f736708509fe0e53bc05df639db75
--- /dev/null
+++ b/basicsr/models/keep_gan_model.py
@@ -0,0 +1,303 @@
+from collections import OrderedDict
+import torch
+import torch.nn.functional as F
+import pdb
+
+from einops import rearrange
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.archs.arch_util import flow_warp, resize_flow
+
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class KEEPGANModel(VideoRecurrentModel):
+ """KEEPGAN Model.
+ """
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+ logger = get_root_logger()
+
+ # # load pretrained VQGAN models
+ # load_path = self.opt['path'].get('pretrain_network_vqgan', None)
+ # if load_path is not None:
+ # param_key = self.opt['path'].get('param_key_vqgan', 'params')
+ # self.load_network(self.net_g, load_path, False, param_key)
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(
+ f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(
+ self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get(
+ 'strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained weights
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ self.load_network(self.net_d, load_path,
+ self.opt['path'].get('strict_load_d', True))
+ self.net_d.train()
+
+ # define losses.
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', False)
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', False)
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+
+ if self.cross_entropy_loss:
+ self.generate_idx_gt = True
+ assert self.opt.get(
+ 'network_vqgan', None) is not None, f'Shoule have network_vqgan config or pre-calculated latent code.'
+ self.hq_vqgan_fix = build_network(
+ self.opt['network_vqgan']).to(self.device)
+ self.hq_vqgan_fix.eval()
+ for param in self.hq_vqgan_fix.parameters():
+ param.requires_grad = False
+ # load_path = self.opt['path'].get('pretrain_network_vqgan', None)
+ # assert load_path != None, "Should load pre-trained VQGAN"
+ # self.load_network(self.hq_vqgan_fix, load_path, strict=False)
+ else:
+ self.generate_idx_gt = False
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
+
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.perceptual_type = train_opt['perceptual_opt']['type']
+ self.cri_perceptual = build_loss(
+ train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('temporal_opt'):
+ self.temporal_type = train_opt.get('temporal_warp_type', 'GT')
+ self.cri_temporal = build_loss(
+ train_opt['temporal_opt']).to(self.device)
+ else:
+ self.cri_temporal = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ logger = get_root_logger()
+
+ optim_names, freezed_names = [], []
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ optim_names.append(k)
+ else:
+ freezed_names.append(k)
+
+ logger.warning(f'--------------- Optimizing Params ---------------.')
+ for k in optim_names:
+ logger.warning(f'Params {k} will be optimized.')
+ logger.warning(f'--------------- Freezing Params ---------------.')
+ for k in freezed_names:
+ logger.warning(f'Params {k} will be freezed.')
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(
+ optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(
+ optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+ self.optimizer_g.zero_grad()
+
+ if self.generate_idx_gt:
+ with torch.no_grad():
+ b, f, c, h, w = self.gt.shape
+ x = self.hq_vqgan_fix.encoder(self.gt.reshape(-1, c, h, w))
+ _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+ min_encoding_indices = quant_stats['min_encoding_indices']
+ self.idx_gt = min_encoding_indices.view(b*f, -1)
+
+ if self.hq_feat_loss or self.cross_entropy_loss:
+ self.output, logits, lq_feat, gen_feat_dict = self.net_g(
+ self.lq, detach_16=True, early_feat=True)
+ else:
+ self.output, gen_feat_dict = self.net_g(
+ self.lq, detach_16=True, early_feat=False)
+ if len(gen_feat_dict) == 0:
+ gen_feat_dict['HR'] = self.output
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ # hq_feat_loss
+ if self.hq_feat_loss: # codebook loss
+ code_h = lq_feat.shape[-1]
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(
+ self.idx_gt, shape=[b*f, code_h, code_h, 256])
+ l_feat_encoder = torch.mean(
+ (quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
+ l_g_total += l_feat_encoder
+ loss_dict['l_feat_encoder'] = l_feat_encoder
+
+ # cross_entropy_loss
+ if self.cross_entropy_loss:
+ # b(hw)n -> bn(hw)
+ cross_entropy_loss = F.cross_entropy(logits.permute(
+ 0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+ l_g_total += cross_entropy_loss
+ loss_dict['l_cross_entropy'] = cross_entropy_loss
+
+ # Temporal consistency loss
+ if self.cri_temporal:
+ assert len(
+ gen_feat_dict) != 0, "Empty features for temporal regularization."
+ with torch.no_grad():
+ if self.temporal_type == 'GT':
+ flows = self.net_g.module.get_flow(self.gt).detach()
+ flows = rearrange(flows, "b f c h w -> (b f) c h w")
+ elif self.temporal_type == 'HR':
+ flows = self.net_g.module.get_flow(self.output).detach()
+ flows = rearrange(flows, "b f c h w -> (b f) c h w")
+ elif self.temporal_type == 'Diff':
+ gt_flows = self.net_g.module.get_flow(self.gt).detach()
+ gt_flows = rearrange(gt_flows, "b f c h w -> (b f) c h w")
+ hr_flows = self.net_g.module.get_flow(self.output).detach()
+ hr_flows = rearrange(hr_flows, "b f c h w -> (b f) c h w")
+ else:
+ raise ValueError(
+ f'Unsupported temporal mode: {self.temporal_type}.')
+
+ l_temporal = 0
+ for f_size, feat in gen_feat_dict.items():
+ b, f, c, h, w = feat.shape
+
+ if self.temporal_type == 'GT' or self.temporal_type == 'HR':
+ flow = resize_flow(flows, 'shape', [h, w]) # B*(T-1) 2 H W
+ flow = rearrange(flow, "b c h w -> b h w c")
+ prev_feat = feat[:, :-1, ...].reshape(-1, c, h, w)
+ curr_feat = feat[:, 1:, ...].reshape(-1, c, h, w)
+ warp_feat = flow_warp(prev_feat, flow)
+ l_temporal += self.cri_temporal(curr_feat, warp_feat)
+ elif self.temporal_type == 'Diff':
+ gt_flow = resize_flow(gt_flows, 'shape', [
+ h, w]) # B*(T-1) 2 H W
+ gt_flow = rearrange(gt_flow, "b c h w -> b h w c")
+ hr_flow = resize_flow(hr_flows, 'shape', [
+ h, w]) # B*(T-1) 2 H W
+ hr_flow = rearrange(hr_flow, "b c h w -> b h w c")
+
+ prev_feat = feat[:, :-1, ...].reshape(-1, c, h, w)
+ curr_feat = feat[:, 1:, ...].reshape(-1, c, h, w)
+ gt_warp_feat = flow_warp(prev_feat, gt_flow)
+ hr_warp_feat = flow_warp(prev_feat, hr_flow)
+ l_temporal += self.cri_temporal(gt_warp_feat, hr_warp_feat)
+
+ l_g_total += l_temporal
+ loss_dict['l_temporal'] = l_temporal
+
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_pix
+ loss_dict['l_pix'] = l_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ B, T, C, H, W = self.gt.shape
+ if self.perceptual_type == 'PerceptualLoss':
+ l_percep, l_style = self.cri_perceptual(
+ self.output.view(-1, C, H, W), self.gt.view(-1, C, H, W))
+ if l_percep is not None:
+ l_g_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_g_total += l_style
+ loss_dict['l_style'] = l_style
+ elif self.perceptual_type == 'LPIPSLoss':
+ l_percep = self.cri_perceptual(
+ self.output.view(-1, C, H, W), self.gt.view(-1, C, H, W))
+ l_g_total += l_percep
+ loss_dict['l_percep'] = l_percep
+
+ # gan loss
+ if current_iter > self.net_d_start_iter:
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ # optimize net_d
+ if current_iter > self.net_d_start_iter:
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+ self.optimizer_d.zero_grad()
+
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g',
+ current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/keep_model.py b/basicsr/models/keep_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..51f8b736b666f83613de2713710bd0047520704d
--- /dev/null
+++ b/basicsr/models/keep_model.py
@@ -0,0 +1,242 @@
+from collections import OrderedDict
+import torch
+import torch.nn.functional as F
+import pdb
+
+from einops import rearrange
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.archs.arch_util import flow_warp, resize_flow
+
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class KEEPModel(VideoRecurrentModel):
+ """KEEP Model.
+ """
+
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+ logger = get_root_logger()
+
+ # # load pretrained VQGAN models
+ # load_path = self.opt['path'].get('pretrain_network_vqgan', None)
+ # if load_path is not None:
+ # param_key = self.opt['path'].get('param_key_vqgan', 'params')
+ # self.load_network(self.net_g, load_path, False, param_key)
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(
+ f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(
+ self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get(
+ 'strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define losses.
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', False)
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', False)
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+
+ if self.cross_entropy_loss:
+ self.generate_idx_gt = True
+ assert self.opt.get(
+ 'network_vqgan', None) is not None, f'Shoule have network_vqgan config or pre-calculated latent code.'
+ self.hq_vqgan_fix = build_network(
+ self.opt['network_vqgan']).to(self.device)
+ self.hq_vqgan_fix.eval()
+ for param in self.hq_vqgan_fix.parameters():
+ param.requires_grad = False
+ # load_path = self.opt['path'].get('pretrain_network_vqgan', None)
+ # assert load_path != None, "Should load pre-trained VQGAN"
+ # self.load_network(self.hq_vqgan_fix, load_path, strict=False)
+ else:
+ self.generate_idx_gt = False
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
+
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.perceptual_type = train_opt['perceptual_opt']['type']
+ self.cri_perceptual = build_loss(
+ train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('temporal_opt'):
+ self.temporal_type = train_opt.get('temporal_warp_type', 'GT')
+ self.cri_temporal = build_loss(
+ train_opt['temporal_opt']).to(self.device)
+ else:
+ self.cri_temporal = None
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ logger = get_root_logger()
+
+ optim_names, freezed_names = [], []
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ optim_names.append(k)
+ else:
+ freezed_names.append(k)
+
+ logger.warning(f'--------------- Optimizing Params ---------------.')
+ for k in optim_names:
+ logger.warning(f'Params {k} will be optimized.')
+ logger.warning(f'--------------- Freezing Params ---------------.')
+ for k in freezed_names:
+ logger.warning(f'Params {k} will be freezed.')
+
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(
+ optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ self.optimizer_g.zero_grad()
+
+ if self.generate_idx_gt:
+ with torch.no_grad():
+ b, f, c, h, w = self.gt.shape
+ x = self.hq_vqgan_fix.encoder(self.gt.reshape(-1, c, h, w))
+ _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+ min_encoding_indices = quant_stats['min_encoding_indices']
+ self.idx_gt = min_encoding_indices.view(b*f, -1)
+
+ if self.hq_feat_loss or self.cross_entropy_loss:
+ self.output, logits, lq_feat, gen_feat_dict = self.net_g(
+ self.lq, detach_16=True, early_feat=True)
+ else:
+ self.output, gen_feat_dict = self.net_g(
+ self.lq, detach_16=True, early_feat=False)
+ if len(gen_feat_dict) == 0:
+ gen_feat_dict['HR'] = self.output
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ # hq_feat_loss
+ if self.hq_feat_loss: # codebook loss
+ code_h = lq_feat.shape[-1]
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(
+ self.idx_gt, shape=[b*f, code_h, code_h, 256])
+ l_feat_encoder = torch.mean(
+ (quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
+ l_g_total += l_feat_encoder
+ loss_dict['l_feat_encoder'] = l_feat_encoder
+
+ # cross_entropy_loss
+ if self.cross_entropy_loss:
+ # b(hw)n -> bn(hw)
+ cross_entropy_loss = F.cross_entropy(logits.permute(
+ 0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+ l_g_total += cross_entropy_loss
+ loss_dict['l_cross_entropy'] = cross_entropy_loss
+
+ # Temporal consistency loss
+ if self.cri_temporal:
+ assert len(
+ gen_feat_dict) != 0, "Empty features for temporal regularization."
+ with torch.no_grad():
+ if self.temporal_type == 'GT':
+ flows = self.net_g.module.get_flow(self.gt).detach()
+ flows = rearrange(flows, "b f c h w -> (b f) c h w")
+ elif self.temporal_type == 'HR':
+ flows = self.net_g.module.get_flow(self.output).detach()
+ flows = rearrange(flows, "b f c h w -> (b f) c h w")
+ elif self.temporal_type == 'Diff':
+ gt_flows = self.net_g.module.get_flow(self.gt).detach()
+ gt_flows = rearrange(gt_flows, "b f c h w -> (b f) c h w")
+ hr_flows = self.net_g.module.get_flow(self.output).detach()
+ hr_flows = rearrange(hr_flows, "b f c h w -> (b f) c h w")
+ else:
+ raise ValueError(
+ f'Unsupported temporal mode: {self.temporal_type}.')
+
+ l_temporal = 0
+ for f_size, feat in gen_feat_dict.items():
+ b, f, c, h, w = feat.shape
+
+ if self.temporal_type == 'GT' or self.temporal_type == 'HR':
+ flow = resize_flow(flows, 'shape', [h, w]) # B*(T-1) 2 H W
+ flow = rearrange(flow, "b c h w -> b h w c")
+ prev_feat = feat[:, :-1, ...].view(-1, c, h, w)
+ curr_feat = feat[:, 1:, ...].view(-1, c, h, w)
+ warp_feat = flow_warp(prev_feat, flow)
+ l_temporal += self.cri_temporal(curr_feat, warp_feat)
+ elif self.temporal_type == 'Diff':
+ gt_flow = resize_flow(gt_flows, 'shape', [
+ h, w]) # B*(T-1) 2 H W
+ gt_flow = rearrange(gt_flow, "b c h w -> b h w c")
+ hr_flow = resize_flow(hr_flows, 'shape', [
+ h, w]) # B*(T-1) 2 H W
+ hr_flow = rearrange(hr_flow, "b c h w -> b h w c")
+
+ prev_feat = feat[:, :-1, ...].view(-1, c, h, w)
+ curr_feat = feat[:, 1:, ...].view(-1, c, h, w)
+ gt_warp_feat = flow_warp(prev_feat, gt_flow)
+ hr_warp_feat = flow_warp(prev_feat, hr_flow)
+ l_temporal += self.cri_temporal(gt_warp_feat, hr_warp_feat)
+
+ l_g_total += l_temporal
+ loss_dict['l_temporal'] = l_temporal
+
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_pix
+ loss_dict['l_pix'] = l_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ B, T, C, H, W = self.gt.shape
+ if self.perceptual_type == 'PerceptualLoss':
+ l_percep, l_style = self.cri_perceptual(
+ self.output.view(-1, C, H, W), self.gt.view(-1, C, H, W))
+ if l_percep is not None:
+ l_g_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_g_total += l_style
+ loss_dict['l_style'] = l_style
+ elif self.perceptual_type == 'LPIPSLoss':
+ l_percep = self.cri_perceptual(
+ self.output.view(-1, C, H, W), self.gt.view(-1, C, H, W))
+ l_g_total += l_percep
+ loss_dict['l_percep'] = l_percep
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf
--- /dev/null
+++ b/basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+ scheduler will restart with the weights in restart_weights.
+
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ periods (list): Period for each cosine anneling cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The minimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f4b9fe05fc8d144ec9973c2ec5fd3aa201718dd
--- /dev/null
+++ b/basicsr/models/sr_model.py
@@ -0,0 +1,209 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+ """Base SR model for single image super-resolution."""
+
+ def __init__(self, opt):
+ super(SRModel, self).__init__(opt)
+
+ # define network
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ optim_params = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def feed_data(self, data):
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+
+ def optimize_parameters(self, current_iter):
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_total = 0
+ loss_dict = OrderedDict()
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_total += l_pix
+ loss_dict['l_pix'] = l_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+ if l_percep is not None:
+ l_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_total += l_style
+ loss_dict['l_style'] = l_style
+
+ l_total.backward()
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def test(self):
+ if hasattr(self, 'ema_decay'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(self.lq)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+ self.net_g.train()
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['lq'] = self.lq.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ if hasattr(self, 'gt'):
+ out_dict['gt'] = self.gt.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'ema_decay'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c080c2662a5df157d4368e7e1fc6b1d8accedee0
--- /dev/null
+++ b/basicsr/models/video_recurrent_model.py
@@ -0,0 +1,308 @@
+import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+import cv2
+import os
+
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentModel(SRModel):
+ """Video Recurrent SR model (merged with VideoBaseModel)."""
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ flow_lr_mul = train_opt.get('flow_lr_mul', 1)
+ logger = get_root_logger()
+ logger.info(
+ f'Multiple the learning rate for flow network with {flow_lr_mul}.')
+ if flow_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate flow params and normal params for different lr
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name:
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': flow_params,
+ 'lr': train_opt['optim_g']['lr'] * flow_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(
+ optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if hasattr(self, 'fix_flow_iter') and self.fix_flow_iter:
+ logger = get_root_logger()
+ if current_iter == 1:
+ logger.info(
+ f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Train all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ super(VideoRecurrentModel, self).optimize_parameters(current_iter)
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset = dataloader.dataset
+ dataset_name = dataset.opt['name']
+ with_metrics = self.opt['val']['metrics'] is not None
+ save_video = self.opt['val'].get('save_video', False)
+ # initialize self.metric_results
+ # It is a dict: {
+ # 'folder1': tensor (num_frame x len(metrics)),
+ # 'folder2': tensor (num_frame x len(metrics))
+ # }
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ rank, world_size = get_dist_info()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
+
+ metric_data = dict()
+ num_folders = len(dataset)
+ num_pad = (world_size - (num_folders % world_size)) % world_size
+ if rank == 0:
+ pbar = tqdm(total=len(dataset), unit='folder')
+ # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
+ # (To avoid wait-dead)
+ for i in range(rank, num_folders + num_pad, world_size):
+ idx = min(i, num_folders - 1)
+ val_data = dataset[idx]
+ folder = val_data['folder']
+
+ # compute outputs
+ val_data['lq'].unsqueeze_(0)
+ val_data['gt'].unsqueeze_(0)
+ self.feed_data(val_data)
+ val_data['lq'].squeeze_(0)
+ val_data['gt'].squeeze_(0)
+
+ self.test()
+ visuals = self.get_current_visuals()
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ if 'gt' in visuals:
+ del self.gt
+ torch.cuda.empty_cache()
+
+ if hasattr(self, 'center_frame_only') and self.center_frame_only:
+ visuals['result'] = visuals['result'].unsqueeze(1)
+ if 'gt' in visuals:
+ visuals['gt'] = visuals['gt'].unsqueeze(1)
+
+ # # For EDVR
+ # result = visuals['result']
+ # result_img = tensor2img([result])
+
+ # if save_img:
+ # if self.opt['is_train']:
+ # raise NotImplementedError(
+ # 'saving image is not supported during training.')
+ # else:
+ # img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ # f"{idx:08d}.png")
+ # # image name only for REDS dataset
+ # imwrite(result_img, img_path)
+
+ # evaluate
+ if i < num_folders:
+ video_writer = None
+ for idx in range(visuals['result'].size(1)):
+ result = visuals['result'][0, idx, :, :, :]
+ result_img = tensor2img(
+ [result], min_max=(-1, 1)) # uint8, bgr
+ metric_data['img1'] = result_img
+ if 'gt' in visuals:
+ gt = visuals['gt'][0, idx, :, :, :]
+ gt_img = tensor2img(
+ [gt], min_max=(-1, 1)) # uint8, bgr
+ metric_data['img2'] = gt_img
+
+ if save_img:
+ if self.opt['is_train']:
+ raise NotImplementedError(
+ 'saving image is not supported during training.')
+ else:
+ if hasattr(self, 'center_frame_only') and self.center_frame_only: # vimeo-90k
+ clip_ = val_data['lq_path'].split('/')[-3]
+ seq_ = val_data['lq_path'].split('/')[-2]
+ name_ = f'{clip_}_{seq_}'
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{name_}_{self.opt['name']}.png")
+ else: # others
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{idx:08d}.png")
+ imwrite(result_img, img_path)
+
+ if save_video:
+ if self.opt['is_train']:
+ raise NotImplementedError(
+ 'saving image is not supported during training.')
+ else:
+ if video_writer is None:
+ video_output_path = osp.join(self.opt['path']['visualization'], dataset_name+'_video',
+ f"{folder}.mp4")
+ dir_name = osp.abspath(
+ osp.dirname(video_output_path))
+ os.makedirs(dir_name, exist_ok=True)
+ frame_rate = 15
+ h, w = result_img.shape[:2]
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ video_writer = cv2.VideoWriter(video_output_path, fourcc,
+ frame_rate, (w, h))
+ video_writer.write(result_img)
+
+ # calculate metrics
+ if with_metrics:
+ for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+ result = calculate_metric(metric_data, opt_)
+ self.metric_results[folder][idx,
+ metric_idx] += result
+
+ if save_video:
+ cv2.destroyAllWindows()
+ video_writer.release()
+
+ # progress bar
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update(1)
+ pbar.set_description(f'Folder: {folder}')
+
+ if rank == 0:
+ pbar.close()
+
+ if with_metrics:
+ if self.opt['dist']:
+ # collect data among GPUs
+ for _, tensor in self.metric_results.items():
+ dist.reduce(tensor, 0)
+ dist.barrier()
+
+ if rank == 0:
+ self._log_validation_metric_values(
+ current_iter, dataset_name, tb_logger)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ logger = get_root_logger()
+ logger.warning(
+ 'nondist_validation is not implemented. Run dist_validation.')
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ # ----------------- calculate the average values for each folder, and for each metric ----------------- #
+ # average all frames for each sub-folder
+ # metric_results_avg is a dict:{
+ # 'folder1': tensor (len(metrics)),
+ # 'folder2': tensor (len(metrics))
+ # }
+ metric_results_avg = {
+ folder: torch.mean(tensor, dim=0).cpu()
+ for (folder, tensor) in self.metric_results.items()
+ }
+ # total_avg_results is a dict: {
+ # 'metric1': float,
+ # 'metric2': float
+ # }
+ total_avg_results = {
+ metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ for folder, tensor in metric_results_avg.items():
+ for idx, metric in enumerate(total_avg_results.keys()):
+ total_avg_results[metric] += metric_results_avg[folder][idx].item()
+ # average among folders
+ for metric in total_avg_results.keys():
+ total_avg_results[metric] /= len(metric_results_avg)
+ # update the best metric result
+ self._update_best_metric_result(
+ dataset_name, metric, total_avg_results[metric], current_iter)
+
+ # ------------------------------------------ log the metric ------------------------------------------ #
+ log_str = f'Validation {dataset_name}\n'
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ for folder, tensor in metric_results_avg.items():
+ log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}\n'
+ if hasattr(self, 'best_metric_results'):
+ log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += '\n'
+
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+ for folder, tensor in metric_results_avg.items():
+ tb_logger.add_scalar(
+ f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
+
+ def test(self):
+ n = self.lq.size(1)
+ self.net_g.eval()
+
+ flip_seq = self.opt['val'].get('flip_seq', False)
+ self.center_frame_only = self.opt['val'].get('center_frame_only', False)
+
+ if flip_seq:
+ self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
+
+ with torch.no_grad():
+ video_length = self.lq.shape[1]
+ fix_length = 20
+ if video_length > fix_length:
+ output = []
+ for start_idx in range(0, video_length, fix_length):
+ end_idx = min(start_idx + fix_length, video_length)
+ if end_idx - start_idx == 1:
+ output.append(self.net_g(
+ self.lq[:, [start_idx, start_idx], ...])[:, 0:1, ...])
+ else:
+ output.append(self.net_g(
+ self.lq[:, start_idx:end_idx, ...]))
+ self.output = torch.cat(output, dim=1)
+ assert self.output.shape[1] == video_length, "Differer number of frames"
+ else:
+ self.output = self.net_g(self.lq)
+
+ if flip_seq:
+ output_1 = self.output[:, :n, :, :, :]
+ output_2 = self.output[:, n:, :, :, :].flip(1)
+ self.output = 0.5 * (output_1 + output_2)
+
+ if hasattr(self, 'center_frame_only') and self.center_frame_only:
+ self.output = self.output[:, n // 2, :, :, :]
+
+ self.net_g.train()
diff --git a/basicsr/models/vqgan_model.py b/basicsr/models/vqgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b7734e7b2069f0122776b202971d6a154f182b
--- /dev/null
+++ b/basicsr/models/vqgan_model.py
@@ -0,0 +1,284 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VQGANModel(SRModel):
+ def feed_data(self, data):
+ self.gt = data['gt'].to(self.device)
+ self.b = self.gt.shape[0]
+
+
+ def init_training_settings(self):
+ logger = get_root_logger()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ if train_opt.get('codebook_opt'):
+ self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
+ else:
+ self.l_weight_codebook = 1.0
+
+ self.vqgan_quantizer = self.opt['network_g']['quantizer']
+ logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
+
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+ self.disc_weight = train_opt.get('disc_weight', 0.8)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
+ def adopt_weight(self, weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ loss_dict = OrderedDict()
+ if self.opt['network_g']['quantizer'] == 'gumbel':
+ self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
+ if current_iter%1000 == 0:
+ logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output, l_codebook, quant_stats = self.net_g(self.gt)
+
+ l_codebook = l_codebook*self.l_weight_codebook
+
+ l_g_total = 0
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+
+ # gan loss
+ if current_iter > self.net_d_start_iter:
+ # fake_g_pred = self.net_d(self.output_1024)
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ recon_loss = l_g_total
+ last_layer = self.net_g.module.generator.blocks[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+ d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
+ d_weight *= self.disc_weight # tamming setting 0.8
+ l_g_total += d_weight * l_g_gan
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
+
+ l_g_total += l_codebook
+ loss_dict['l_codebook'] = l_codebook
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ if current_iter > self.net_d_start_iter:
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+
+ def test(self):
+ with torch.no_grad():
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ self.output, _, _ = self.net_g_ema(self.gt)
+ else:
+ logger = get_root_logger()
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
+ self.net_g.eval()
+ self.output, _, _ = self.net_g(self.gt)
+ self.net_g.train()
+
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['gt'] = self.gt.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff
--- /dev/null
+++ b/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..734154f9ed9447d585eae7df6886acb136f8a3cf
--- /dev/null
+++ b/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,377 @@
+import math
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+ from . import deform_conv_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ 'deform_conv',
+ sources=[
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+ ],
+ )
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_ext.deform_conv_forward(input, weight,
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} is not divisible ' \
+ f'by groups {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
diff --git a/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d9424908ed2dbd4ac3cdb98d13e09287a4d2f2d
--- /dev/null
+++ b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include
+#include
+
+#include
+#include
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide int group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
diff --git a/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34
--- /dev/null
+++ b/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t val = static_cast(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ scalar_t *data_col_ = data_col.data_ptr();
+
+ deformable_im2col_gpu_kernel<<>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ scalar_t *grad_im_ = grad_im.data_ptr();
+
+ deformable_col2im_gpu_kernel<<>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr();
+ const scalar_t *data_im_ = data_im.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ scalar_t *grad_offset_ = grad_offset.data_ptr();
+
+ deformable_col2im_coord_gpu_kernel<<>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
+template
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t val = static_cast(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+template
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+template
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ const scalar_t *data_mask_ = data_mask.data_ptr();
+ scalar_t *data_col_ = data_col.data_ptr();
+
+ modulated_deformable_im2col_gpu_kernel<<>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ const scalar_t *data_mask_ = data_mask.data_ptr();
+ scalar_t *grad_im_ = grad_im.data_ptr();
+
+ modulated_deformable_col2im_gpu_kernel<<>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr();
+ const scalar_t *data_im_ = data_im.data_ptr();
+ const scalar_t *data_offset_ = data_offset.data_ptr();
+ const scalar_t *data_mask_ = data_mask.data_ptr();
+ scalar_t *grad_offset_ = grad_offset.data_ptr();
+ scalar_t *grad_mask_ = grad_mask.data_ptr();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/basicsr/ops/dcn/src/deform_conv_ext.cpp b/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042
--- /dev/null
+++ b/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include
+#include
+
+#include
+#include
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias);
+#endif
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_forward_cuda(input, weight, offset, output, columns,
+ ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+ deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_input_cuda(input, offset, gradOutput,
+ gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+ dilationW, dilationH, group, deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_parameters(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+ gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+ dilationH, group, deformable_group, scale, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+ offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+ stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+ deformable_group, with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+ offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+ grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+ pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+ with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward", &deform_conv_forward,
+ "deform forward");
+ m.def("deform_conv_backward_input", &deform_conv_backward_input,
+ "deform_conv_backward_input");
+ m.def("deform_conv_backward_parameters",
+ &deform_conv_backward_parameters,
+ "deform_conv_backward_parameters");
+ m.def("modulated_deform_conv_forward",
+ &modulated_deform_conv_forward,
+ "modulated deform conv forward");
+ m.def("modulated_deform_conv_backward",
+ &modulated_deform_conv_backward,
+ "modulated deform conv backward");
+}
diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422
--- /dev/null
+++ b/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..588f815e596ab0fc83ab0f9d21426c22ec5ed7c3
--- /dev/null
+++ b/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,89 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+ from . import fused_act_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ fused_act_ext = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+ ],
+ )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+ ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/basicsr/ops/fused_act/src/fused_bias_act.cpp b/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee
--- /dev/null
+++ b/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0
--- /dev/null
+++ b/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+
+template
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<<>>(
+ y.data_ptr(),
+ x.data_ptr(),
+ b.data_ptr(),
+ ref.data_ptr(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
diff --git a/basicsr/ops/upfirdn2d/__init__.py b/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
@@ -0,0 +1,24 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast(*x_p) * static_cast(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel
+ <<>>(out.data_ptr(),
+ x.data_ptr(),
+ k.data_ptr(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<<>>(
+ out.data_ptr(), x.data_ptr(),
+ k.data_ptr(), p);
+ }
+ });
+
+ return out;
+}
diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..667f96e1ded35d48f163f37e21d1ed8ff191aac3
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,186 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+ from . import upfirdn2d_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_ext = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+ ],
+ )
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == 'cpu':
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ else:
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/basicsr/test.py b/basicsr/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d200bc90b2fb42f4aa58202eb4251d46c943c76e
--- /dev/null
+++ b/basicsr/test.py
@@ -0,0 +1,49 @@
+import logging
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.models import build_model
+from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
+from basicsr.utils.options import dict2str, parse_options
+
+
+def test_pipeline(root_path):
+ # parse options, set distributed setting, set ramdom seed
+ opt, _ = parse_options(root_path, is_train=False)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # mkdir and initialize loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'],
+ f"test_{opt['name']}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr',
+ log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # create test dataset and dataloader
+ test_loaders = []
+ for _, dataset_opt in sorted(opt['datasets'].items()):
+ test_set = build_dataset(dataset_opt)
+ test_loader = build_dataloader(
+ test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(
+ f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
+ test_loaders.append(test_loader)
+
+ # create model
+ model = build_model(opt)
+
+ for test_loader in test_loaders:
+ test_set_name = test_loader.dataset.opt['name']
+ logger.info(f'Testing {test_set_name}...')
+ model.validation(
+ test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ test_pipeline(root_path)
diff --git a/basicsr/train.py b/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d588e8370dd8458229c49f5182e2cecddd081d2
--- /dev/null
+++ b/basicsr/train.py
@@ -0,0 +1,242 @@
+import warnings
+from basicsr.utils.options import dict2str, parse
+from basicsr.utils.dist_util import get_dist_info, init_dist
+from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
+from basicsr.models import build_model
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data import build_dataloader, build_dataset
+import argparse
+import datetime
+import logging
+import math
+import copy
+import random
+import time
+import pdb
+import torch
+from os import path as osp
+import os
+# os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
+
+
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True,
+ help='Path to option YAML file.')
+ parser.add_argument(
+ '--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ opt = parse(args.opt, root_path, is_train=is_train)
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ return opt
+
+
+def init_loggers(opt):
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(logger_name='basicsr',
+ log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # initialize wandb logger before tensorboard logger to allow proper sync:
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
+ assert opt['logger'].get(
+ 'use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+ init_wandb_logger(opt)
+ tb_logger = None
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
+ return logger, tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+ # create train and val dataloaders
+ train_loader, val_loader = None, None
+ for phase, dataset_opt in opt['datasets'].items():
+ if phase == 'train':
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+ train_set = build_dataset(dataset_opt)
+ train_sampler = EnlargedSampler(
+ train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+ train_loader = build_dataloader(
+ train_set,
+ dataset_opt,
+ num_gpu=opt['num_gpu'],
+ dist=opt['dist'],
+ sampler=train_sampler,
+ seed=opt['manual_seed'])
+
+ num_iter_per_epoch = math.ceil(
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+ total_iters = int(opt['train']['total_iter'])
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+ logger.info('Training statistics:'
+ f'\n\tNumber of train images: {len(train_set)}'
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+
+ elif phase == 'val':
+ val_set = build_dataset(dataset_opt)
+ val_loader = build_dataloader(
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(
+ f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
+ else:
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
+
+
+def train_pipeline(root_path):
+ # parse options, set distributed setting, set ramdom seed
+ opt = parse_options(root_path, is_train=True)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ if opt['path'].get('resume_state'):
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(
+ opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ resume_state = None
+
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
+
+ # initialize loggers
+ logger, tb_logger = init_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+
+ # create model
+ if resume_state: # resume training
+ check_resume(opt, resume_state['iter'])
+ model = build_model(opt)
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ logger.info(
+ f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+ else:
+ model = build_model(opt)
+ start_epoch = 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ # dataloader prefetcher
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+ if prefetch_mode is None or prefetch_mode == 'cpu':
+ prefetcher = CPUPrefetcher(train_loader)
+ elif prefetch_mode == 'cuda':
+ prefetcher = CUDAPrefetcher(train_loader, opt)
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
+ if opt['datasets']['train'].get('pin_memory') is not True:
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+ else:
+ raise ValueError(
+ f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.")
+
+ # training
+ logger.info(
+ f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
+ data_time, iter_time = time.time(), time.time()
+ start_time = time.time()
+
+ for epoch in range(start_epoch, total_epochs + 1):
+ train_sampler.set_epoch(epoch)
+ prefetcher.reset()
+ train_data = prefetcher.next()
+
+ while train_data is not None:
+ data_time = time.time() - data_time
+
+ current_iter += 1
+ if current_iter > total_iters:
+ break
+ # update learning rate
+ model.update_learning_rate(
+ current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+ # training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_iter)
+ iter_time = time.time() - iter_time
+ # log
+ if current_iter % opt['logger']['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': model.get_current_learning_rate()})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ # save models and training states
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+ logger.info('Saving models and training states.')
+ model.save(epoch, current_iter)
+
+ # validation
+ if opt.get('val') is not None and opt['datasets'].get('val') is not None \
+ and (current_iter % opt['val']['val_freq'] == 0):
+ model.validation(val_loader, current_iter,
+ tb_logger, opt['val']['save_img'])
+
+ data_time = time.time()
+ iter_time = time.time()
+ train_data = prefetcher.next()
+ # end of iter
+
+ # end of epoch
+
+ consumed_time = str(datetime.timedelta(
+ seconds=int(time.time() - start_time)))
+ logger.info(f'End of training. Time consumed: {consumed_time}')
+ logger.info('Save the latest model.')
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
+ if opt.get('val') is not None and opt['datasets'].get('val'):
+ model.validation(val_loader, current_iter,
+ tb_logger, opt['val']['save_img'])
+ if tb_logger:
+ tb_logger.close()
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcc1d540462712387523d1e326d1dfc2bcfbf32
--- /dev/null
+++ b/basicsr/utils/__init__.py
@@ -0,0 +1,29 @@
+from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt'
+]
diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0
--- /dev/null
+++ b/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a267915743ee3f3232bc8fe992466b52468979a
--- /dev/null
+++ b/basicsr/utils/download_util.py
@@ -0,0 +1,95 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ print(response_file_size)
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f38d9796da3899048924f2f803d1088927966b0
--- /dev/null
+++ b/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+ # mc.pyvector servers as a point which points to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing differnet lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text in a specified backend from its path
+ and return it as a binary file. it can also register other backend
+ accessor with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7180b4e9b5c8f2eb36a9a0e4ff6affdaae84b8
--- /dev/null
+++ b/basicsr/utils/flow_util.py
@@ -0,0 +1,170 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
+import cv2
+import numpy as np
+import os
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_path (ndarray or str): Flow path.
+ quantize (bool): whether to read quantized pair, if set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if quantize:
+ assert concat_axis in [0, 1]
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+ if cat_flow.ndim != 2:
+ raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+ else:
+ with open(flow_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
+ will be concatenated horizontally into a single image if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+ images. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ cv2.imwrite(filename, dxdy)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+ tuple: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+ tuple: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
+
+ return dequantized_arr
diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aba82ce08eefaeb3e56ea5a3a09c342ae513522
--- /dev/null
+++ b/basicsr/utils/img_util.py
@@ -0,0 +1,171 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+ float32 (bool): Whether to change to float32., If True, will also norm
+ to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+ crop_border (int): Crop border for each end of height and weight.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
+
\ No newline at end of file
diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723
--- /dev/null
+++ b/basicsr/utils/lmdb_util.py
@@ -0,0 +1,196 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+ print(f'Totoal images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cce46dd670fe989faf52e818a0ff2774eca1b64
--- /dev/null
+++ b/basicsr/utils/logger.py
@@ -0,0 +1,172 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class MessageLogger():
+ """Message logger for printing.
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+ logger (dict): Contains 'print_freq' (str) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['logger']['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['train']['total_iter']
+ self.use_tb_logger = opt['logger']['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (
+ f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger and 'debug' not in self.exp_name:
+ # if k.startswith('l_'):
+ # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+ # else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+ logger = logging.getLogger('basicsr')
+
+ project = opt['logger']['wandb']['project']
+ resume_id = opt['logger']['wandb'].get('resume_id')
+ if resume_id:
+ wandb_id = resume_id
+ resume = 'allow'
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = 'never'
+
+ wandb.init(id=wandb_id, resume=resume,
+ name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel('ERROR')
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ # Shangchen: keep the previous log
+ file_handler = logging.FileHandler(log_file, 'a')
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+def get_env_info():
+ """Get environment information.
+ Currently, only log the software version.
+ """
+ import torch
+ import torchvision
+
+ from basicsr.version import __version__
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += ('\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}')
+ return msg
diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ce1004a2c9f8521505c4b5889d3c24a909c70d
--- /dev/null
+++ b/basicsr/utils/matlab_functions.py
@@ -0,0 +1,347 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+ kernel_width (int): Kernel width.
+ antialisaing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+ antialisaing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+ """
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+ antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+ antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+ if numpy_type:
+ out_2 = out_2.numpy().transpose(1, 2, 0)
+ return out_2
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ convertion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace convertion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d62d8f1db757e803575e6c1b74f4b54a679790c
--- /dev/null
+++ b/basicsr/utils/misc.py
@@ -0,0 +1,161 @@
+import os
+import re
+import random
+import time
+import torch
+import numpy as np
+from os import path as osp
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",
+ torch.__version__)[0][:3])] >= [1, 12, 0]
+
+
+def gpu_is_available():
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return True
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+
+def get_device(gpu_id=None):
+ if gpu_id is None:
+ gpu_str = ''
+ elif isinstance(gpu_id, int):
+ gpu_str = f':{gpu_id}'
+ else:
+ raise TypeError('Input should be int value.')
+
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return torch.device('mps'+gpu_str)
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key) and ('param_key' not in key):
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the interested files with relative pathes.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ logger = get_root_logger()
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ logger.warning(
+ 'pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (basename
+ not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(
+ opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+ str: Formated file siz.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf13cbc0b944a8c3577862dc9148d81040063c71
--- /dev/null
+++ b/basicsr/utils/options.py
@@ -0,0 +1,294 @@
+import yaml
+import time
+import os
+from collections import OrderedDict
+from os import path as osp
+from basicsr.utils.misc import get_time_str
+import argparse
+import random
+import torch
+from collections import OrderedDict
+
+from basicsr.utils import set_random_seed
+from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def yaml_load(f):
+ """Load yaml file or string.
+
+ Args:
+ f (str): File path or a python string.
+
+ Returns:
+ dict: Loaded dict.
+ """
+ if os.path.isfile(f):
+ with open(f, 'r') as f:
+ return yaml.load(f, Loader=ordered_yaml()[0])
+ else:
+ return yaml.load(f, Loader=ordered_yaml()[0])
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+def _postprocess_yml_value(value):
+ # None
+ if value == '~' or value.lower() == 'none':
+ return None
+ # bool
+ if value.lower() == 'true':
+ return True
+ elif value.lower() == 'false':
+ return False
+ # !!float number
+ if value.startswith('!!float'):
+ return float(value.replace('!!float', ''))
+ # number
+ if value.isdigit():
+ return int(value)
+ elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
+ return float(value)
+ # list
+ if value.startswith('['):
+ return eval(value)
+ # str
+ return value
+
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True,
+ help='Path to option YAML file.')
+ parser.add_argument(
+ '--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument(
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
+ args = parser.parse_args()
+
+ # parse yml to dict
+ opt = yaml_load(args.opt)
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ # force to update yml options
+ if args.force_yml is not None:
+ for entry in args.force_yml:
+ # now do not support creating new keys
+ keys, value = entry.split('=')
+ keys, value = keys.strip(), value.strip()
+ value = _postprocess_yml_value(value)
+ eval_str = 'opt'
+ for key in keys.split(':'):
+ eval_str += f'["{key}"]'
+ eval_str += '=value'
+ # using exec function
+ exec(eval_str)
+
+ opt['auto_resume'] = args.auto_resume
+ opt['is_train'] = is_train
+
+ # debug setting
+ if args.debug and not opt['name'].startswith('debug'):
+ opt['name'] = 'debug_' + opt['name']
+
+ if opt['num_gpu'] == 'auto':
+ opt['num_gpu'] = torch.cuda.device_count()
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for multiple datasets, e.g., val_1, val_2; test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = opt['path'].get('experiments_root')
+ if experiments_root is None:
+ experiments_root = osp.join(root_path, 'experiments')
+ experiments_root = osp.join(experiments_root, opt['name'])
+
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(
+ experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(
+ experiments_root, 'visualization')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ if 'val' in opt:
+ opt['val']['val_freq'] = 8
+ opt['logger']['print_freq'] = 1
+ opt['logger']['save_checkpoint_freq'] = 8
+ else: # test
+ results_root = opt['path'].get('results_root')
+ if results_root is None:
+ results_root = osp.join(root_path, 'results')
+ results_root = osp.join(results_root, opt['name'])
+
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt, args
+
+
+def parse(opt_path, root_path, is_train=True):
+ """Parse option file.
+
+ Args:
+ opt_path (str): Option file path.
+ is_train (str): Indicate whether in training or not. Default: True.
+
+ Returns:
+ (dict): Options.
+ """
+ with open(opt_path, mode='r') as f:
+ Loader, _ = ordered_yaml()
+ opt = yaml.load(f, Loader=Loader)
+
+ opt['is_train'] = is_train
+
+ # if opt['path'].get('resume_state', None): # Shangchen added
+ # resume_state_path = opt['path'].get('resume_state')
+ # opt['name'] = resume_state_path.split("/")[-3]
+ # else:
+ # opt['name'] = f"{get_time_str()}_{opt['name']}"
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for several datasets, e.g., test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(
+ experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(
+ experiments_root, 'visualization')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ if 'val' in opt:
+ opt['val']['val_freq'] = 8
+ opt['logger']['print_freq'] = 1
+ opt['logger']['save_checkpoint_freq'] = 8
+
+ else: # test
+ results_root = osp.join(root_path, 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
diff --git a/basicsr/utils/realesrgan_utils.py b/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..466866b8cca82a38723fdd2d5f01ca024374ac05
--- /dev/null
+++ b/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,305 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from torch.nn import functional as F
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import get_device
+import pdb
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
+ model (nn.Module): The defined network. Default: None.
+ tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
+ input images into tiles, and then process each of them. Finally, they will be merged into one image.
+ 0 denotes for do not use tile. Default: 0.
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+ half (float): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ model_path,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ # if gpu_id:
+ # self.device = torch.device(
+ # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ # else:
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+
+ self.device = get_device(gpu_id) if device is None else device
+
+ # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
+ if model_path.startswith('https://'):
+ model_path = load_file_from_url(
+ url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
+ # prefer to use params_ema
+ if 'params_ema' in loadnet:
+ keyname = 'params_ema'
+ else:
+ keyname = 'params'
+ model.load_state_dict(loadnet[keyname], strict=True)
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ # pdb.set_trace()
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ try:
+ with torch.no_grad():
+ # pdb.set_trace()
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img_t = self.post_process()
+ output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ del output_img_t
+ torch.cuda.empty_cache()
+ except RuntimeError as error:
+ print(f"Failed inference for RealESRGAN: {error}")
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+ img_list (list[str]): A image list of image paths to be read.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == 'quit':
+ break
+
+ output = msg['output']
+ save_path = msg['save_path']
+ cv2.imwrite(save_path, output)
+ print(f'IO worker {self.qid} is done.')
\ No newline at end of file
diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827
--- /dev/null
+++ b/basicsr/utils/registry.py
@@ -0,0 +1,82 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj):
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None):
+ """
+ Register the given object under the the name `obj.__name__`.
+ Can be used as either a decorator or not.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj)
+
+ def get(self, name):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
diff --git a/basicsr/utils/video_util.py b/basicsr/utils/video_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e1a3370b8eca64c84469063eac7df353ba7b28
--- /dev/null
+++ b/basicsr/utils/video_util.py
@@ -0,0 +1,127 @@
+'''
+The code is modified from the Real-ESRGAN:
+https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py
+
+'''
+import cv2
+import sys
+import numpy as np
+
+try:
+ import ffmpeg
+except ImportError:
+ import pip
+ pip.main(['install', '--user', 'ffmpeg-python'])
+ import ffmpeg
+
+def get_video_meta_info(video_path):
+ ret = {}
+ probe = ffmpeg.probe(video_path)
+ video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
+ has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
+ ret['width'] = video_streams[0]['width']
+ ret['height'] = video_streams[0]['height']
+ ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
+ ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
+ ret['nb_frames'] = int(video_streams[0]['nb_frames'])
+ return ret
+
+class VideoReader:
+ def __init__(self, video_path):
+ self.paths = [] # for image&folder type
+ self.audio = None
+ try:
+ self.stream_reader = (
+ ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
+ loglevel='error').run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+ except FileNotFoundError:
+ print('Please install ffmpeg (not ffmpeg-python) by running\n',
+ '\t$ conda install -c conda-forge ffmpeg')
+ sys.exit(0)
+
+ meta = get_video_meta_info(video_path)
+ self.width = meta['width']
+ self.height = meta['height']
+ self.input_fps = meta['fps']
+ self.audio = meta['audio']
+ self.nb_frames = meta['nb_frames']
+
+ self.idx = 0
+
+ def get_resolution(self):
+ return self.height, self.width
+
+ def get_fps(self):
+ if self.input_fps is not None:
+ return self.input_fps
+ return 24
+
+ def get_audio(self):
+ return self.audio
+
+ def __len__(self):
+ return self.nb_frames
+
+ def get_frame_from_stream(self):
+ img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
+ if not img_bytes:
+ return None
+ img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
+ return img
+
+ def get_frame_from_list(self):
+ if self.idx >= self.nb_frames:
+ return None
+ img = cv2.imread(self.paths[self.idx])
+ self.idx += 1
+ return img
+
+ def get_frame(self):
+ return self.get_frame_from_stream()
+
+
+ def close(self):
+ self.stream_reader.stdin.close()
+ self.stream_reader.wait()
+
+
+class VideoWriter:
+ def __init__(self, video_save_path, height, width, fps, audio=None):
+ if height > 2160:
+ print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
+ 'We highly recommend to decrease the outscale(aka, -s).')
+ if audio is not None:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
+ framerate=fps).output(
+ audio,
+ video_save_path,
+ pix_fmt='yuv420p',
+ vcodec='libx264',
+ loglevel='error',
+ acodec='copy').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+ else:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
+ framerate=fps).output(
+ video_save_path,
+ pix_fmt='yuv420p',
+ vcodec='libx264',
+ loglevel='error').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+
+ def write_frame(self, frame):
+ try:
+ frame = frame.astype(np.uint8).tobytes()
+ self.stream_writer.stdin.write(frame)
+ except BrokenPipeError:
+ print('Please re-install ffmpeg and libx264 by running\n',
+ '\t$ conda install -c conda-forge ffmpeg\n',
+ '\t$ conda install -c conda-forge x264')
+ sys.exit(0)
+
+ def close(self):
+ self.stream_writer.stdin.close()
+ self.stream_writer.wait()
\ No newline at end of file
diff --git a/facelib/detection/__init__.py b/facelib/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3112e10166eb92a72bd7450561446bd882ccac44
--- /dev/null
+++ b/facelib/detection/__init__.py
@@ -0,0 +1,71 @@
+import os
+import torch
+from torch import nn
+from copy import deepcopy
+
+from facelib.utils import load_file_from_url
+from facelib.utils import download_pretrained_models
+from facelib.detection.yolov5face.models.common import Conv
+
+from .retinaface.retinaface import RetinaFace
+from .yolov5face.face_detector import YoloDetector
+
+
+def init_detection_model(model_name, half=False, device='cuda'):
+ if 'retinaface' in model_name:
+ model = init_retinaface_model(model_name, half, device)
+ elif 'YOLOv5' in model_name:
+ model = init_yolov5face_model(model_name, device)
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ return model
+
+
+def init_retinaface_model(model_name, half=False, device='cuda'):
+ if model_name == 'retinaface_resnet50':
+ model = RetinaFace(network_name='resnet50', half=half)
+ model_url = 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/detection_Resnet50_Final.pth'
+ elif model_name == 'retinaface_mobile0.25':
+ model = RetinaFace(network_name='mobile0.25', half=half)
+ model_url = 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+
+ return model
+
+
+def init_yolov5face_model(model_name, device='cuda'):
+ if model_name == 'YOLOv5l':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+ model_url = 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/yolov5l-face.pth'
+ elif model_name == 'YOLOv5n':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+ model_url = 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/yolov5n-face.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+ model.detector.load_state_dict(load_net, strict=True)
+ model.detector.eval()
+ model.detector = model.detector.to(device).float()
+
+ for m in model.detector.modules():
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+ m.inplace = True # pytorch 1.7.0 compatibility
+ elif isinstance(m, Conv):
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+ return model
diff --git a/facelib/detection/align_trans.py b/facelib/detection/align_trans.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f1eb365462c2ec5bbac6d1854c786b6fd6be90
--- /dev/null
+++ b/facelib/detection/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+ [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+ def __str__(self):
+ return 'In File {}:{}'.format(__file__, super.__str__(self))
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+ """
+ Function:
+ ----------
+ get reference 5 key points according to crop settings:
+ 0. Set default crop_size:
+ if default_square:
+ crop_size = (112, 112)
+ else:
+ crop_size = (96, 112)
+ 1. Pad the crop_size by inner_padding_factor in each side;
+ 2. Resize crop_size into (output_size - outer_padding*2),
+ pad into output_size with outer_padding;
+ 3. Output reference_5point;
+ Parameters:
+ ----------
+ @output_size: (w, h) or None
+ size of aligned face image
+ @inner_padding_factor: (w_factor, h_factor)
+ padding factor for inner (w, h)
+ @outer_padding: (w_pad, h_pad)
+ each row is a pair of coordinates (x, y)
+ @default_square: True or False
+ if True:
+ default crop_size = (112, 112)
+ else:
+ default crop_size = (96, 112);
+ !!! make sure, if output_size is not None:
+ (output_size - outer_padding)
+ = some_scale * (default crop_size * (1.0 +
+ inner_padding_factor))
+ Returns:
+ ----------
+ @reference_5point: 5x2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+ # 0) make the inner region a square
+ if default_square:
+ size_diff = max(tmp_crop_size) - tmp_crop_size
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += size_diff
+
+ if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+ return tmp_5pts
+
+ if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+ if output_size is None:
+ return tmp_5pts
+ else:
+ raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+ # check output size
+ if not (0 <= inner_padding_factor <= 1.0):
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+ output_size = tmp_crop_size * \
+ (1 + inner_padding_factor * 2).astype(np.int32)
+ output_size += np.array(outer_padding)
+ if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+ # 1) pad the inner region according inner_padding_factor
+ if inner_padding_factor > 0:
+ size_diff = tmp_crop_size * inner_padding_factor * 2
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+ # 2) resize the padded inner region
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+ raise FaceWarpException('Must have (output_size - outer_padding)'
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
+
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+ tmp_5pts = tmp_5pts * scale_factor
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+ # tmp_5pts = tmp_5pts + size_diff / 2
+ tmp_crop_size = size_bf_outer_pad
+
+ # 3) add outer_padding to make output_size
+ reference_5point = tmp_5pts + np.array(outer_padding)
+ tmp_crop_size = output_size
+
+ return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+ """
+ Function:
+ ----------
+ get affine transform matrix 'tfm' from src_pts to dst_pts
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points matrix, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points matrix, each row is a pair of coordinates (x, y)
+ Returns:
+ ----------
+ @tfm: 2x3 np.array
+ transform matrix from src_pts to dst_pts
+ """
+
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+ n_pts = src_pts.shape[0]
+ ones = np.ones((n_pts, 1), src_pts.dtype)
+ src_pts_ = np.hstack([src_pts, ones])
+ dst_pts_ = np.hstack([dst_pts, ones])
+
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+ if rank == 3:
+ tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+ elif rank == 2:
+ tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+ return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+ Parameters:
+ ----------
+ @src_img: 3x3 np.array
+ input image
+ @facial_pts: could be
+ 1)a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ @reference_pts: could be
+ 1) a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ or
+ 3) None
+ if None, use default reference facial points
+ @crop_size: (w, h)
+ output face image size
+ @align_type: transform type, could be one of
+ 1) 'similarity': use similarity transform
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
+ by calling cv2.getAffineTransform()
+ 3) 'affine': use all points to do affine transform
+ Returns:
+ ----------
+ @face_img: output face image with size (w, h) = @crop_size
+ """
+
+ if reference_pts is None:
+ if crop_size[0] == 96 and crop_size[1] == 112:
+ reference_pts = REFERENCE_FACIAL_POINTS
+ else:
+ default_square = False
+ inner_padding_factor = 0
+ outer_padding = (0, 0)
+ output_size = crop_size
+
+ reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
+ default_square)
+
+ ref_pts = np.float32(reference_pts)
+ ref_pts_shp = ref_pts.shape
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+ raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if ref_pts_shp[0] == 2:
+ ref_pts = ref_pts.T
+
+ src_pts = np.float32(facial_pts)
+ src_pts_shp = src_pts.shape
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+ raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if src_pts_shp[0] == 2:
+ src_pts = src_pts.T
+
+ if src_pts.shape != ref_pts.shape:
+ raise FaceWarpException('facial_pts and reference_pts must have the same shape')
+
+ if align_type == 'cv2_affine':
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+ elif align_type == 'affine':
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
+ else:
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+ return face_img
diff --git a/facelib/detection/matlab_cp2tform.py b/facelib/detection/matlab_cp2tform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a8b54a91709c71437e15c68d3be9a9b0a20a34
--- /dev/null
+++ b/facelib/detection/matlab_cp2tform.py
@@ -0,0 +1,317 @@
+import numpy as np
+from numpy.linalg import inv, lstsq
+from numpy.linalg import matrix_rank as rank
+from numpy.linalg import norm
+
+
+class MatlabCp2tormException(Exception):
+
+ def __str__(self):
+ return 'In File {}:{}'.format(__file__, super.__str__(self))
+
+
+def tformfwd(trans, uv):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+ uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy = np.dot(uv, trans)
+ xy = xy[:, 0:-1]
+ return xy
+
+
+def tforminv(trans, uv):
+ """
+ Function:
+ ----------
+ apply the inverse of affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of inverse-transformed coordinates (x, y)
+ """
+ Tinv = inv(trans)
+ xy = tformfwd(Tinv, uv)
+ return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ K = options['K']
+ M = xy.shape[0]
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+ X = np.vstack((tmp1, tmp2))
+
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ U = np.vstack((u, v))
+
+ # We know that X * r = U
+ if rank(X) >= 2 * K:
+ r, _, _, _ = lstsq(X, U, rcond=-1)
+ r = np.squeeze(r)
+ else:
+ raise Exception('cp2tform:twoUniquePointsReq')
+ sc = r[0]
+ ss = r[1]
+ tx = r[2]
+ ty = r[3]
+
+ Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+ T = inv(Tinv)
+ T[:, 2] = np.array([0, 0, 1])
+
+ return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ # uv = np.array(uv)
+ # xy = np.array(xy)
+
+ # Solve for trans1
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+ # Solve for trans2
+
+ # manually reflect the xy data across the Y-axis
+ xyR = xy
+ xyR[:, 0] = -1 * xyR[:, 0]
+
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+ # manually reflect the tform to undo the reflection done on xyR
+ TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ trans2 = np.dot(trans2r, TreflectY)
+
+ # Figure out if trans1 or trans2 is better
+ xy1 = tformfwd(trans1, uv)
+ norm1 = norm(xy1 - xy)
+
+ xy2 = tformfwd(trans2, uv)
+ norm2 = norm(xy2 - xy)
+
+ if norm1 <= norm2:
+ return trans1, trans1_inv
+ else:
+ trans2_inv = inv(trans2)
+ return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'trans':
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y, 1] = [u, v, 1] * trans
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ @reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+ trans_inv: 3x3 np.array
+ inverse of trans, transform matrix from xy to uv
+ """
+
+ if reflective:
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
+ else:
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+ return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+ """
+ Function:
+ ----------
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ cv2_trans = trans[:, 0:2].T
+
+ return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+ return cv2_trans
+
+
+if __name__ == '__main__':
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ # In Matlab, run:
+ #
+ # uv = [u'; v'];
+ # xy = [x'; y'];
+ # tform_sim=cp2tform(uv,xy,'similarity');
+ #
+ # trans = tform_sim.tdata.T
+ # ans =
+ # -0.0764 -1.6190 0
+ # 1.6190 -0.0764 0
+ # -3.2156 0.0290 1.0000
+ # trans_inv = tform_sim.tdata.Tinv
+ # ans =
+ #
+ # -0.0291 0.6163 0
+ # -0.6163 -0.0291 0
+ # -0.0756 1.9826 1.0000
+ # xy_m=tformfwd(tform_sim, u,v)
+ #
+ # xy_m =
+ #
+ # -3.2156 0.0290
+ # 1.1833 -9.9143
+ # 5.0323 2.8853
+ # uv_m=tforminv(tform_sim, x,y)
+ #
+ # uv_m =
+ #
+ # 0.5698 1.3953
+ # 6.0872 2.2733
+ # -2.6570 4.3314
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ uv = np.array((u, v)).T
+ xy = np.array((x, y)).T
+
+ print('\n--->uv:')
+ print(uv)
+ print('\n--->xy:')
+ print(xy)
+
+ trans, trans_inv = get_similarity_transform(uv, xy)
+
+ print('\n--->trans matrix:')
+ print(trans)
+
+ print('\n--->trans_inv matrix:')
+ print(trans_inv)
+
+ print('\n---> apply transform to uv')
+ print('\nxy_m = uv_augmented * trans')
+ uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy_m = np.dot(uv_aug, trans)
+ print(xy_m)
+
+ print('\nxy_m = tformfwd(trans, uv)')
+ xy_m = tformfwd(trans, uv)
+ print(xy_m)
+
+ print('\n---> apply inverse transform to xy')
+ print('\nuv_m = xy_augmented * trans_inv')
+ xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+ uv_m = np.dot(xy_aug, trans_inv)
+ print(uv_m)
+
+ print('\nuv_m = tformfwd(trans_inv, xy)')
+ uv_m = tformfwd(trans_inv, xy)
+ print(uv_m)
+
+ uv_m = tforminv(trans, xy)
+ print('\nuv_m = tforminv(trans, xy)')
+ print(uv_m)
diff --git a/facelib/detection/retinaface/retinaface.py b/facelib/detection/retinaface/retinaface.py
new file mode 100644
index 0000000000000000000000000000000000000000..48e3fa76e4996278367876126eecf5e4b396f2de
--- /dev/null
+++ b/facelib/detection/retinaface/retinaface.py
@@ -0,0 +1,393 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
+from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
+from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
+ py_cpu_nms)
+
+from basicsr.utils.misc import get_device
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = get_device()
+# device = 'cpu'
+
+
+def generate_config(network_name):
+
+ cfg_mnet = {
+ 'name': 'mobilenet0.25',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 640,
+ 'return_layers': {
+ 'stage1': 1,
+ 'stage2': 2,
+ 'stage3': 3
+ },
+ 'in_channel': 32,
+ 'out_channel': 64
+ }
+
+ cfg_re50 = {
+ 'name': 'Resnet50',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 24,
+ 'ngpu': 4,
+ 'epoch': 100,
+ 'decay1': 70,
+ 'decay2': 90,
+ 'image_size': 840,
+ 'return_layers': {
+ 'layer2': 1,
+ 'layer3': 2,
+ 'layer4': 3
+ },
+ 'in_channel': 256,
+ 'out_channel': 256
+ }
+
+ if network_name == 'mobile0.25':
+ return cfg_mnet
+ elif network_name == 'resnet50':
+ return cfg_re50
+ else:
+ raise NotImplementedError(f'network_name={network_name}')
+
+
+class RetinaFace(nn.Module):
+
+ def __init__(self, network_name='resnet50', half=False, phase='test'):
+ super(RetinaFace, self).__init__()
+ self.half_inference = half
+ cfg = generate_config(network_name)
+ self.backbone = cfg['name']
+
+ self.model_name = f'retinaface_{network_name}'
+ self.cfg = cfg
+ self.phase = phase
+ self.target_size, self.max_size = 1600, 2150
+ self.resize, self.scale, self.scale1 = 1., None, None
+ self.mean_tensor = torch.tensor(
+ [[[[104.]], [[117.]], [[123.]]]]).to(device)
+ self.reference = get_reference_facial_points(default_square=True)
+ # Build network.
+ backbone = None
+ if cfg['name'] == 'mobilenet0.25':
+ backbone = MobileNetV1()
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+ elif cfg['name'] == 'Resnet50':
+ import torchvision.models as models
+ backbone = models.resnet50(pretrained=False)
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+
+ in_channels_stage2 = cfg['in_channel']
+ in_channels_list = [
+ in_channels_stage2 * 2,
+ in_channels_stage2 * 4,
+ in_channels_stage2 * 8,
+ ]
+
+ out_channels = cfg['out_channel']
+ self.fpn = FPN(in_channels_list, out_channels)
+ self.ssh1 = SSH(out_channels, out_channels)
+ self.ssh2 = SSH(out_channels, out_channels)
+ self.ssh3 = SSH(out_channels, out_channels)
+
+ self.ClassHead = make_class_head(
+ fpn_num=3, inchannels=cfg['out_channel'])
+ self.BboxHead = make_bbox_head(
+ fpn_num=3, inchannels=cfg['out_channel'])
+ self.LandmarkHead = make_landmark_head(
+ fpn_num=3, inchannels=cfg['out_channel'])
+
+ self.to(device)
+ self.eval()
+ if self.half_inference:
+ self.half()
+
+ def forward(self, inputs):
+ out = self.body(inputs)
+
+ if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
+ out = list(out.values())
+ # FPN
+ fpn = self.fpn(out)
+
+ # SSH
+ feature1 = self.ssh1(fpn[0])
+ feature2 = self.ssh2(fpn[1])
+ feature3 = self.ssh3(fpn[2])
+ features = [feature1, feature2, feature3]
+
+ bbox_regressions = torch.cat(
+ [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ classifications = torch.cat(
+ [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ tmp = [self.LandmarkHead[i](feature)
+ for i, feature in enumerate(features)]
+ ldm_regressions = (torch.cat(tmp, dim=1))
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(
+ classifications, dim=-1), ldm_regressions)
+ return output
+
+ def __detect_faces(self, inputs):
+ # get scale
+ height, width = inputs.shape[2:]
+ self.scale = torch.tensor(
+ [width, height, width, height], dtype=torch.float32).to(device)
+ tmp = [width, height, width, height, width,
+ height, width, height, width, height]
+ self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
+
+ # forawrd
+ inputs = inputs.to(device)
+ if self.half_inference:
+ inputs = inputs.half()
+ loc, conf, landmarks = self(inputs)
+
+ # get priorbox
+ priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+ priors = priorbox.forward().to(device)
+
+ return loc, conf, landmarks, priors
+
+ # single image detection
+ def transform(self, image, use_origin_size):
+ # convert to opencv format
+ if isinstance(image, Image.Image):
+ image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ image = image.astype(np.float32)
+
+ # testing scale
+ im_size_min = np.min(image.shape[0:2])
+ im_size_max = np.max(image.shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ image = cv2.resize(image, None, None, fx=resize,
+ fy=resize, interpolation=cv2.INTER_LINEAR)
+
+ # convert to torch.tensor format
+ # image -= (104, 117, 123)
+ image = image.transpose(2, 0, 1)
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ return image, resize
+
+ def detect_faces(
+ self,
+ image,
+ conf_threshold=0.8,
+ nms_threshold=0.4,
+ use_origin_size=True,
+ ):
+ """
+ Params:
+ imgs: BGR image
+ """
+ image, self.resize = self.transform(image, use_origin_size)
+ image = image.to(device)
+ if self.half_inference:
+ image = image.half()
+ image = image - self.mean_tensor
+
+ loc, conf, landmarks, priors = self.__detect_faces(image)
+
+ boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
+ boxes = boxes * self.scale / self.resize
+ boxes = boxes.cpu().numpy()
+
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+ landmarks = decode_landm(landmarks.squeeze(
+ 0), priors, self.cfg['variance'])
+ landmarks = landmarks * self.scale1 / self.resize
+ landmarks = landmarks.cpu().numpy()
+
+ # ignore low scores
+ inds = np.where(scores > conf_threshold)[0]
+ boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+ # sort
+ order = scores.argsort()[::-1]
+ boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+ # do NMS
+ bounding_boxes = np.hstack(
+ (boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+ # self.t['forward_pass'].toc()
+ # print(self.t['forward_pass'].average_time)
+ # import sys
+ # sys.stdout.flush()
+ return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+ def __align_multi(self, image, boxes, landmarks, limit=None):
+
+ if len(boxes) < 1:
+ return [], []
+
+ if limit:
+ boxes = boxes[:limit]
+ landmarks = landmarks[:limit]
+
+ faces = []
+ for landmark in landmarks:
+ facial5points = [[landmark[2 * j], landmark[2 * j + 1]]
+ for j in range(5)]
+
+ warped_face = warp_and_crop_face(
+ np.array(image), facial5points, self.reference, crop_size=(112, 112))
+ faces.append(warped_face)
+
+ return np.concatenate((boxes, landmarks), axis=1), faces
+
+ def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+ rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+ boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+ return self.__align_multi(img, boxes, landmarks, limit)
+
+ # batched detection
+ def batched_transform(self, frames, use_origin_size):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
+ type=np.float32, BGR format).
+ use_origin_size: whether to use origin size.
+ """
+ from_PIL = True if isinstance(frames[0], Image.Image) else False
+
+ # convert to opencv format
+ if from_PIL:
+ frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
+ for frame in frames]
+ frames = np.asarray(frames, dtype=np.float32)
+
+ # testing scale
+ im_size_min = np.min(frames[0].shape[0:2])
+ im_size_max = np.max(frames[0].shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ if not from_PIL:
+ frames = F.interpolate(frames, scale_factor=resize)
+ else:
+ frames = [
+ cv2.resize(frame, None, None, fx=resize,
+ fy=resize, interpolation=cv2.INTER_LINEAR)
+ for frame in frames
+ ]
+
+ # convert to torch.tensor format
+ if not from_PIL:
+ frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+ else:
+ frames = frames.transpose((0, 3, 1, 2))
+ frames = torch.from_numpy(frames)
+
+ return frames, resize
+
+ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+ type=np.uint8, BGR format).
+ conf_threshold: confidence threshold.
+ nms_threshold: nms threshold.
+ use_origin_size: whether to use origin size.
+ Returns:
+ final_bounding_boxes: list of np.array ([n_boxes, 5],
+ type=np.float32).
+ final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+ """
+ # self.t['forward_pass'].tic()
+ frames, self.resize = self.batched_transform(frames, use_origin_size)
+ frames = frames.to(device)
+ frames = frames - self.mean_tensor
+
+ b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+ final_bounding_boxes, final_landmarks = [], []
+
+ # decode
+ priors = priors.unsqueeze(0)
+ b_loc = batched_decode(
+ b_loc, priors, self.cfg['variance']) * self.scale / self.resize
+ b_landmarks = batched_decode_landm(
+ b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
+ b_conf = b_conf[:, :, 1]
+
+ # index for selection
+ b_indice = b_conf > conf_threshold
+
+ # concat
+ b_loc_and_conf = torch.cat(
+ (b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+ for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
+
+ # ignore low scores
+ pred, landm = pred[inds, :], landm[inds, :]
+ if pred.shape[0] == 0:
+ final_bounding_boxes.append(np.array([], dtype=np.float32))
+ final_landmarks.append(np.array([], dtype=np.float32))
+ continue
+
+ # sort
+ # order = score.argsort(descending=True)
+ # box, landm, score = box[order], landm[order], score[order]
+
+ # to CPU
+ bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+ # NMS
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+ # append
+ final_bounding_boxes.append(bounding_boxes)
+ final_landmarks.append(landmarks)
+ # self.t['forward_pass'].toc(average=True)
+ # self.batch_time += self.t['forward_pass'].diff
+ # self.total_frame += len(frames)
+ # print(self.batch_time / self.total_frame)
+
+ return final_bounding_boxes, final_landmarks
diff --git a/facelib/detection/retinaface/retinaface_net.py b/facelib/detection/retinaface/retinaface_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6aa82d3e9055a838f1f9076b12f05fdfc154d0
--- /dev/null
+++ b/facelib/detection/retinaface/retinaface_net.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_bn_no_relu(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_dw(inp, oup, stride, leaky=0.1):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
+class SSH(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(SSH, self).__init__()
+ assert out_channel % 4 == 0
+ leaky = 0
+ if (out_channel <= 64):
+ leaky = 0.1
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+ self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+ self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+ self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ def forward(self, input):
+ conv3X3 = self.conv3X3(input)
+
+ conv5X5_1 = self.conv5X5_1(input)
+ conv5X5 = self.conv5X5_2(conv5X5_1)
+
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
+ conv7X7 = self.conv7x7_3(conv7X7_2)
+
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+ out = F.relu(out)
+ return out
+
+
+class FPN(nn.Module):
+
+ def __init__(self, in_channels_list, out_channels):
+ super(FPN, self).__init__()
+ leaky = 0
+ if (out_channels <= 64):
+ leaky = 0.1
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+ self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+ self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+ def forward(self, input):
+ # names = list(input.keys())
+ # input = list(input.values())
+
+ output1 = self.output1(input[0])
+ output2 = self.output2(input[1])
+ output3 = self.output3(input[2])
+
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
+ output2 = output2 + up3
+ output2 = self.merge2(output2)
+
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
+ output1 = output1 + up2
+ output1 = self.merge1(output1)
+
+ out = [output1, output2, output3]
+ return out
+
+
+class MobileNetV1(nn.Module):
+
+ def __init__(self):
+ super(MobileNetV1, self).__init__()
+ self.stage1 = nn.Sequential(
+ conv_bn(3, 8, 2, leaky=0.1), # 3
+ conv_dw(8, 16, 1), # 7
+ conv_dw(16, 32, 2), # 11
+ conv_dw(32, 32, 1), # 19
+ conv_dw(32, 64, 2), # 27
+ conv_dw(64, 64, 1), # 43
+ )
+ self.stage2 = nn.Sequential(
+ conv_dw(64, 128, 2), # 43 + 16 = 59
+ conv_dw(128, 128, 1), # 59 + 32 = 91
+ conv_dw(128, 128, 1), # 91 + 32 = 123
+ conv_dw(128, 128, 1), # 123 + 32 = 155
+ conv_dw(128, 128, 1), # 155 + 32 = 187
+ conv_dw(128, 128, 1), # 187 + 32 = 219
+ )
+ self.stage3 = nn.Sequential(
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
+ conv_dw(256, 256, 1), # 241 + 64 = 301
+ )
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(256, 1000)
+
+ def forward(self, x):
+ x = self.stage1(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.avg(x)
+ # x = self.model(x)
+ x = x.view(-1, 256)
+ x = self.fc(x)
+ return x
+
+
+class ClassHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(ClassHead, self).__init__()
+ self.num_anchors = num_anchors
+ self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(BboxHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(LandmarkHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+ classhead = nn.ModuleList()
+ for i in range(fpn_num):
+ classhead.append(ClassHead(inchannels, anchor_num))
+ return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+ bboxhead = nn.ModuleList()
+ for i in range(fpn_num):
+ bboxhead.append(BboxHead(inchannels, anchor_num))
+ return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+ landmarkhead = nn.ModuleList()
+ for i in range(fpn_num):
+ landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+ return landmarkhead
diff --git a/facelib/detection/retinaface/retinaface_utils.py b/facelib/detection/retinaface/retinaface_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c357757741c6d9bd7ce4d8ce740fefd51850fbf
--- /dev/null
+++ b/facelib/detection/retinaface/retinaface_utils.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import torchvision
+from itertools import product as product
+from math import ceil
+
+
+class PriorBox(object):
+
+ def __init__(self, cfg, image_size=None, phase='train'):
+ super(PriorBox, self).__init__()
+ self.min_sizes = cfg['min_sizes']
+ self.steps = cfg['steps']
+ self.clip = cfg['clip']
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+ self.name = 's'
+
+ def forward(self):
+ anchors = []
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
+
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ keep = torchvision.ops.nms(
+ boxes=torch.Tensor(dets[:, :4]),
+ scores=torch.Tensor(dets[:, 4]),
+ iou_threshold=thresh,
+ )
+
+ return list(keep)
+
+
+def point_form(boxes):
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat(
+ (
+ boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
+ boxes[:, :2] + boxes[:, 2:] / 2),
+ 1) # xmax, ymax
+
+
+def center_size(boxes):
+ """ Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat(
+ (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
+ boxes[:, 2:] - boxes[:, :2],
+ 1) # w, h
+
+
+def intersect(box_a, box_b):
+ """ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+ return iou of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+ return iof of a and b, numpy version for data augenmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when matching boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location 2)confidence
+ 3)landm preds.
+ """
+ # jaccard index
+ overlaps = jaccard(truths, point_form(priors))
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+ for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
+ best_truth_idx[best_prior_idx[j]] = j
+ matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
+ conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
+ conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
+ loc = encode(matches, priors, variances)
+
+ matches_landm = landms[best_truth_idx]
+ landm = encode_landm(matches_landm, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+ landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 10].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded landm (tensor), Shape: [num_priors, 10]
+ """
+
+ # dist b/t match center and prior's center
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
+ priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, :, 2:])
+ # g_cxcy /= priors[:, :, 2:]
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+ # return target for smooth_l1_loss
+ return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ tmp = (
+ priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ )
+ landms = torch.cat(tmp, dim=1)
+ return landms
+
+
+def batched_decode(b_loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ b_loc (tensor): location predictions for loc layers,
+ Shape: [num_batches,num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+ boxes = (
+ priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+ )
+ boxes = torch.cat(boxes, dim=2)
+
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+ boxes[:, :, 2:] += boxes[:, :, :2]
+ return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_batches,num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ landms = (
+ priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+ )
+ landms = torch.cat(landms, dim=2)
+ return landms
+
+
+def log_sum_exp(x):
+ """Utility function for computing log_sum_exp while determining
+ This will be used to determine unaveraged confidence loss across
+ all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+ top_k: (int) The Maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w * h
+ # IoU = i / (area(a) + area(b) - i)
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
+ union = (rem_areas - inter) + area[i]
+ IoU = inter / union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
diff --git a/facelib/detection/yolov5face/__init__.py b/facelib/detection/yolov5face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/facelib/detection/yolov5face/face_detector.py b/facelib/detection/yolov5face/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b27e970e34f5dd1a945928b4e8eda171af36dab
--- /dev/null
+++ b/facelib/detection/yolov5face/face_detector.py
@@ -0,0 +1,141 @@
+import cv2
+import copy
+import re
+import torch
+import numpy as np
+
+from pathlib import Path
+from facelib.detection.yolov5face.models.yolo import Model
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ check_img_size,
+ non_max_suppression_face,
+ scale_coords,
+ scale_coords_landmarks,
+)
+
+# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
+ torch.__version__)[0][:3])] >= [1, 9, 0]
+
+
+def isListempty(inList):
+ if isinstance(inList, list): # Is a list
+ return all(map(isListempty, inList))
+ return False # Not a list
+
+class YoloDetector:
+ def __init__(
+ self,
+ config_name,
+ min_face=10,
+ target_size=None,
+ device='cuda',
+ ):
+ """
+ config_name: name of .yaml config with network configuration from models/ folder.
+ min_face : minimal face size in pixels.
+ target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
+ None for original resolution.
+ """
+ self._class_path = Path(__file__).parent.absolute()
+ self.target_size = target_size
+ self.min_face = min_face
+ self.detector = Model(cfg=config_name)
+ self.device = device
+
+
+ def _preprocess(self, imgs):
+ """
+ Preprocessing image before passing through the network. Resize and conversion to torch tensor.
+ """
+ pp_imgs = []
+ for img in imgs:
+ h0, w0 = img.shape[:2] # orig hw
+ if self.target_size:
+ r = self.target_size / min(h0, w0) # resize image to img_size
+ if r < 1:
+ img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+ imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
+ img = letterbox(img, new_shape=imgsz)[0]
+ pp_imgs.append(img)
+ pp_imgs = np.array(pp_imgs)
+ pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
+ pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
+ pp_imgs = pp_imgs.float() # uint8 to fp16/32
+ return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
+
+ def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+ """
+ Postprocessing of raw pytorch model output.
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ bboxes = [[] for _ in range(len(origimgs))]
+ landmarks = [[] for _ in range(len(origimgs))]
+
+ pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+ for image_id, origimg in enumerate(origimgs):
+ img_shape = origimg.shape
+ image_height, image_width = img_shape[:2]
+ gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
+ gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
+ det = pred[image_id].cpu()
+ scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
+ scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
+
+ for j in range(det.size()[0]):
+ box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+ box = list(
+ map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
+ )
+ if box[3] - box[1] < self.min_face:
+ continue
+ lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+ lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
+ lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
+ bboxes[image_id].append(box)
+ landmarks[image_id].append(lm)
+ return bboxes, landmarks
+
+ def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+ """
+ Get bbox coordinates and keypoints of faces on original image.
+ Params:
+ imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
+ conf_thres: confidence threshold for each prediction
+ iou_thres: threshold for NMS (filter of intersecting bboxes)
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ # Pass input images through face detector
+ images = imgs if isinstance(imgs, list) else [imgs]
+ images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+ origimgs = copy.deepcopy(images)
+
+ images = self._preprocess(images)
+
+ if IS_HIGH_VERSION:
+ with torch.inference_mode(): # for pytorch>=1.9
+ pred = self.detector(images)[0]
+ else:
+ with torch.no_grad(): # for pytorch<1.9
+ pred = self.detector(images)[0]
+
+ bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+ # return bboxes, points
+ if not isListempty(points):
+ bboxes = np.array(bboxes).reshape(-1,4)
+ points = np.array(points).reshape(-1,10)
+ padding = bboxes[:,0].reshape(-1,1)
+ return np.concatenate((bboxes, padding, points), axis=1)
+ else:
+ return None
+
+ def __call__(self, *args):
+ return self.predict(*args)
diff --git a/facelib/detection/yolov5face/models/__init__.py b/facelib/detection/yolov5face/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/facelib/detection/yolov5face/models/common.py b/facelib/detection/yolov5face/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..497a00444c4c59725001993a63fe4617e9d323c8
--- /dev/null
+++ b/facelib/detection/yolov5face/models/common.py
@@ -0,0 +1,299 @@
+# This file contains modules common to various models
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ make_divisible,
+ non_max_suppression,
+ scale_coords,
+ xyxy2xywh,
+)
+
+
+def autopad(k, p=None): # kernel, padding
+ # Pad to 'same'
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
+ return p
+
+
+def channel_shuffle(x, groups):
+ batchsize, num_channels, height, width = x.data.size()
+ channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc")
+
+ # reshape
+ x = x.view(batchsize, groups, channels_per_group, height, width)
+ x = torch.transpose(x, 1, 2).contiguous()
+
+ # flatten
+ return x.view(batchsize, -1, height, width)
+
+
+def DWConv(c1, c2, k=1, s=1, act=True):
+ # Depthwise convolution
+ return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class Conv(nn.Module):
+ # Standard convolution
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+ def forward(self, x):
+ return self.act(self.bn(self.conv(x)))
+
+ def fuseforward(self, x):
+ return self.act(self.conv(x))
+
+
+class StemBlock(nn.Module):
+ def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
+ super().__init__()
+ self.stem_1 = Conv(c1, c2, k, s, p, g, act)
+ self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0)
+ self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1)
+ self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
+ self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0)
+
+ def forward(self, x):
+ stem_1_out = self.stem_1(x)
+ stem_2a_out = self.stem_2a(stem_1_out)
+ stem_2b_out = self.stem_2b(stem_2a_out)
+ stem_2p_out = self.stem_2p(stem_1_out)
+ return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1))
+
+
+class Bottleneck(nn.Module):
+ # Standard bottleneck
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
+class C3(nn.Module):
+ # CSP Bottleneck with 3 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class ShuffleV2Block(nn.Module):
+ def __init__(self, inp, oup, stride):
+ super().__init__()
+
+ if not 1 <= stride <= 3:
+ raise ValueError("illegal stride value")
+ self.stride = stride
+
+ branch_features = oup // 2
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(inp),
+ nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+ else:
+ self.branch1 = nn.Sequential()
+
+ self.branch2 = nn.Sequential(
+ nn.Conv2d(
+ inp if (self.stride > 1) else branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ ),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(branch_features),
+ nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+
+ @staticmethod
+ def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+ return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+ def forward(self, x):
+ if self.stride == 1:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+ else:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+ out = channel_shuffle(out, 2)
+ return out
+
+
+class SPP(nn.Module):
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
+ def __init__(self, c1, c2, k=(5, 9, 13)):
+ super().__init__()
+ c_ = c1 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+ def forward(self, x):
+ x = self.cv1(x)
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class Focus(nn.Module):
+ # Focus wh information into c-space
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+
+
+class Concat(nn.Module):
+ # Concatenate a list of tensors along dimension
+ def __init__(self, dimension=1):
+ super().__init__()
+ self.d = dimension
+
+ def forward(self, x):
+ return torch.cat(x, self.d)
+
+
+class NMS(nn.Module):
+ # Non-Maximum Suppression (NMS) module
+ conf = 0.25 # confidence threshold
+ iou = 0.45 # IoU threshold
+ classes = None # (optional list) filter by class
+
+ def forward(self, x):
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+
+
+class AutoShape(nn.Module):
+ # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+ img_size = 640 # inference size (pixels)
+ conf = 0.25 # NMS confidence threshold
+ iou = 0.45 # NMS IoU threshold
+ classes = None # (optional list) filter by class
+
+ def __init__(self, model):
+ super().__init__()
+ self.model = model.eval()
+
+ def autoshape(self):
+ print("autoShape already enabled, skipping... ") # model already converted to model.autoshape()
+ return self
+
+ def forward(self, imgs, size=640, augment=False, profile=False):
+ # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
+ # PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
+ # numpy: = np.zeros((720,1280,3)) # HWC
+ # torch: = torch.zeros(16,3,720,1280) # BCHW
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
+
+ p = next(self.model.parameters()) # for device and type
+ if isinstance(imgs, torch.Tensor): # torch
+ return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
+
+ # Pre-process
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
+ shape0, shape1 = [], [] # image and inference shapes
+ for i, im in enumerate(imgs):
+ im = np.array(im) # to numpy
+ if im.shape[0] < 5: # image in CHW
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
+ s = im.shape[:2] # HWC
+ shape0.append(s) # image shape
+ g = size / max(s) # gain
+ shape1.append([y * g for y in s])
+ imgs[i] = im # update
+ shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
+ x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32
+
+ # Inference
+ with torch.no_grad():
+ y = self.model(x, augment, profile)[0] # forward
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
+
+ # Post-process
+ for i in range(n):
+ scale_coords(shape1, y[i][:, :4], shape0[i])
+
+ return Detections(imgs, y, self.names)
+
+
+class Detections:
+ # detections class for YOLOv5 inference results
+ def __init__(self, imgs, pred, names=None):
+ super().__init__()
+ d = pred[0].device # device
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations
+ self.imgs = imgs # list of images as numpy arrays
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
+ self.names = names # class names
+ self.xyxy = pred # xyxy pixels
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
+ self.n = len(self.pred)
+
+ def __len__(self):
+ return self.n
+
+ def tolist(self):
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
+ x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+ for d in x:
+ for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]:
+ setattr(d, k, getattr(d, k)[0]) # pop out of list
+ return x
diff --git a/facelib/detection/yolov5face/models/experimental.py b/facelib/detection/yolov5face/models/experimental.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ba4c4420789c92dc0e2aaeb3d5b64859ec728c
--- /dev/null
+++ b/facelib/detection/yolov5face/models/experimental.py
@@ -0,0 +1,45 @@
+# # This file contains experimental modules
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+
+
+class CrossConv(nn.Module):
+ # Cross Convolution Downsample
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+ # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, (1, k), (1, s))
+ self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class MixConv2d(nn.Module):
+ # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+ super().__init__()
+ groups = len(k)
+ if equal_ch: # equal c_ per group
+ i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices
+ c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
+ else: # equal weight.numel() per group
+ b = [c2] + [0] * groups
+ a = np.eye(groups + 1, groups, k=-1)
+ a -= np.roll(a, 1, axis=1)
+ a *= np.array(k) ** 2
+ a[0] = 1
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
+
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+
+ def forward(self, x):
+ return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
diff --git a/facelib/detection/yolov5face/models/yolo.py b/facelib/detection/yolov5face/models/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..70845d972f0bcfd3632fcbac096b23e1b4d4d779
--- /dev/null
+++ b/facelib/detection/yolov5face/models/yolo.py
@@ -0,0 +1,235 @@
+import math
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+import yaml # for torch hub
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import (
+ C3,
+ NMS,
+ SPP,
+ AutoShape,
+ Bottleneck,
+ BottleneckCSP,
+ Concat,
+ Conv,
+ DWConv,
+ Focus,
+ ShuffleV2Block,
+ StemBlock,
+)
+from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
+from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
+from facelib.detection.yolov5face.utils.general import make_divisible
+from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
+
+
+class Detect(nn.Module):
+ stride = None # strides computed during build
+ export = False # onnx export
+
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
+ super().__init__()
+ self.nc = nc # number of classes
+ self.no = nc + 5 + 10 # number of outputs per anchor
+
+ self.nl = len(anchors) # number of detection layers
+ self.na = len(anchors[0]) // 2 # number of anchors
+ self.grid = [torch.zeros(1)] * self.nl # init grid
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+ self.register_buffer("anchors", a) # shape(nl,na,2)
+ self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
+
+ def forward(self, x):
+ z = [] # inference output
+ if self.export:
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i])
+ return x
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i]) # conv
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+ if not self.training: # inference
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+ y = torch.full_like(x[i], 0)
+ y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
+ y[..., 5:15] = x[i][..., 5:15]
+
+ y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
+
+ y[..., 5:7] = (
+ y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x1 y1
+ y[..., 7:9] = (
+ y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x2 y2
+ y[..., 9:11] = (
+ y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x3 y3
+ y[..., 11:13] = (
+ y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x4 y4
+ y[..., 13:15] = (
+ y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x5 y5
+
+ z.append(y.view(bs, -1, self.no))
+
+ return x if self.training else (torch.cat(z, 1), x)
+
+ @staticmethod
+ def _make_grid(nx=20, ny=20):
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+
+
+class Model(nn.Module):
+ def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes
+ super().__init__()
+ self.yaml_file = Path(cfg).name
+ with Path(cfg).open(encoding="utf8") as f:
+ self.yaml = yaml.safe_load(f) # model dict
+
+ # Define model
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
+ if nc and nc != self.yaml["nc"]:
+ self.yaml["nc"] = nc # override yaml value
+
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
+ self.names = [str(i) for i in range(self.yaml["nc"])] # default names
+
+ # Build strides, anchors
+ m = self.model[-1] # Detect()
+ if isinstance(m, Detect):
+ s = 128 # 2x min stride
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
+ m.anchors /= m.stride.view(-1, 1, 1)
+ check_anchor_order(m)
+ self.stride = m.stride
+ self._initialize_biases() # only run once
+
+ def forward(self, x):
+ return self.forward_once(x) # single-scale inference, train
+
+ def forward_once(self, x):
+ y = [] # outputs
+ for m in self.model:
+ if m.f != -1: # if not from previous layer
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
+
+ x = m(x) # run
+ y.append(x if m.i in self.save else None) # save output
+
+ return x
+
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
+ # https://arxiv.org/abs/1708.02002 section 3.3
+ m = self.model[-1] # Detect() module
+ for mi, s in zip(m.m, m.stride): # from
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+ def _print_biases(self):
+ m = self.model[-1] # Detect() module
+ for mi in m.m: # from
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
+ print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
+ print("Fusing layers... ")
+ for m in self.model.modules():
+ if isinstance(m, Conv) and hasattr(m, "bn"):
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
+ delattr(m, "bn") # remove batchnorm
+ m.forward = m.fuseforward # update forward
+ elif type(m) is nn.Upsample:
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+ return self
+
+ def nms(self, mode=True): # add or remove NMS module
+ present = isinstance(self.model[-1], NMS) # last layer is NMS
+ if mode and not present:
+ print("Adding NMS... ")
+ m = NMS() # module
+ m.f = -1 # from
+ m.i = self.model[-1].i + 1 # index
+ self.model.add_module(name=str(m.i), module=m) # add
+ self.eval()
+ elif not mode and present:
+ print("Removing NMS... ")
+ self.model = self.model[:-1] # remove
+ return self
+
+ def autoshape(self): # add autoShape module
+ print("Adding autoShape... ")
+ m = AutoShape(self) # wrap model
+ copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes
+ return m
+
+
+def parse_model(d, ch): # model_dict, input_channels(3)
+ anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
+
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
+ m = eval(m) if isinstance(m, str) else m # eval strings
+ for j, a in enumerate(args):
+ try:
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
+ except:
+ pass
+
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
+ if m in [
+ Conv,
+ Bottleneck,
+ SPP,
+ DWConv,
+ MixConv2d,
+ Focus,
+ CrossConv,
+ BottleneckCSP,
+ C3,
+ ShuffleV2Block,
+ StemBlock,
+ ]:
+ c1, c2 = ch[f], args[0]
+
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+ args = [c1, c2, *args[1:]]
+ if m in [BottleneckCSP, C3]:
+ args.insert(2, n)
+ n = 1
+ elif m is nn.BatchNorm2d:
+ args = [ch[f]]
+ elif m is Concat:
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
+ elif m is Detect:
+ args.append([ch[x + 1] for x in f])
+ if isinstance(args[1], int): # number of anchors
+ args[1] = [list(range(args[1] * 2))] * len(f)
+ else:
+ c2 = ch[f]
+
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
+ t = str(m)[8:-2].replace("__main__.", "") # module type
+ np = sum(x.numel() for x in m_.parameters()) # number params
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
+ layers.append(m_)
+ ch.append(c2)
+ return nn.Sequential(*layers), sorted(save)
diff --git a/facelib/detection/yolov5face/models/yolov5l.yaml b/facelib/detection/yolov5face/models/yolov5l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0532b0e22fa7f59349b178146ffddcfdb368aba6
--- /dev/null
+++ b/facelib/detection/yolov5face/models/yolov5l.yaml
@@ -0,0 +1,47 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 2-P3/8
+ [-1, 9, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 4-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32
+ [-1, 1, SPP, [1024, [3,5,7]]],
+ [-1, 3, C3, [1024, False]], # 8
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 5], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 12
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 3], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 16 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 13], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 19 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 9], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 22 (P5/32-large)
+
+ [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/facelib/detection/yolov5face/models/yolov5n.yaml b/facelib/detection/yolov5face/models/yolov5n.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..caba6bed674aa2213b110f19e04eb352ffbeaf1e
--- /dev/null
+++ b/facelib/detection/yolov5face/models/yolov5n.yaml
@@ -0,0 +1,45 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4
+ [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
+ [-1, 3, ShuffleV2Block, [128, 1]], # 2
+ [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
+ [-1, 7, ShuffleV2Block, [256, 1]], # 4
+ [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
+ [-1, 3, ShuffleV2Block, [512, 1]], # 6
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, C3, [128, False]], # 10
+
+ [-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 2], 1, Concat, [1]], # cat backbone P3
+ [-1, 1, C3, [128, False]], # 14 (P3/8-small)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 11], 1, Concat, [1]], # cat head P4
+ [-1, 1, C3, [128, False]], # 17 (P4/16-medium)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 7], 1, Concat, [1]], # cat head P5
+ [-1, 1, C3, [128, False]], # 20 (P5/32-large)
+
+ [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
diff --git a/facelib/detection/yolov5face/utils/__init__.py b/facelib/detection/yolov5face/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/facelib/detection/yolov5face/utils/autoanchor.py b/facelib/detection/yolov5face/utils/autoanchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4eba3e94888709be7d2a7c7499fbcc1808b4a88
--- /dev/null
+++ b/facelib/detection/yolov5face/utils/autoanchor.py
@@ -0,0 +1,12 @@
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
+ da = a[-1] - a[0] # delta a
+ ds = m.stride[-1] - m.stride[0] # delta s
+ if da.sign() != ds.sign(): # same order
+ print("Reversing anchor order")
+ m.anchors[:] = m.anchors.flip(0)
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
diff --git a/facelib/detection/yolov5face/utils/datasets.py b/facelib/detection/yolov5face/utils/datasets.py
new file mode 100755
index 0000000000000000000000000000000000000000..e672b136f56fd6b05038e24377908361a54fe519
--- /dev/null
+++ b/facelib/detection/yolov5face/utils/datasets.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+ # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
+ shape = img.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better test mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
+ elif scale_fill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return img, ratio, (dw, dh)
diff --git a/facelib/detection/yolov5face/utils/extract_ckpt.py b/facelib/detection/yolov5face/utils/extract_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8b631348f2d0cdea4e5a3594bb59f3e8f34a0f
--- /dev/null
+++ b/facelib/detection/yolov5face/utils/extract_ckpt.py
@@ -0,0 +1,5 @@
+import torch
+import sys
+sys.path.insert(0,'./facelib/detection/yolov5face')
+model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
+torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth')
\ No newline at end of file
diff --git a/facelib/detection/yolov5face/utils/general.py b/facelib/detection/yolov5face/utils/general.py
new file mode 100755
index 0000000000000000000000000000000000000000..1c8e14f56a107ec3a4269c382cfc5168ad780ffc
--- /dev/null
+++ b/facelib/detection/yolov5face/utils/general.py
@@ -0,0 +1,271 @@
+import math
+import time
+
+import numpy as np
+import torch
+import torchvision
+
+
+def check_img_size(img_size, s=32):
+ # Verify img_size is a multiple of stride s
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
+ # if new_size != img_size:
+ # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
+ return new_size
+
+
+def make_divisible(x, divisor):
+ # Returns x evenly divisible by divisor
+ return math.ceil(x / divisor) * divisor
+
+
+def xyxy2xywh(x):
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+ coords[:, :4] /= gain
+ clip_coords(coords, img0_shape)
+ return coords
+
+
+def clip_coords(boxes, img_shape):
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
+
+
+def box_iou(box1, box2):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ return inter / (area1[:, None] + area2 - inter)
+
+
+def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 15 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label = labels[xi]
+ v = torch.zeros((len(label), nc + 15), device=x.device)
+ v[:, :4] = label[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, landmarks, cls)
+ if multi_label:
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 15:].max(1, keepdim=True)
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # If none remain process next image
+ n = x.shape[0] # number of boxes
+ if not n:
+ continue
+
+ # Batched NMS
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ break # time limit exceeded
+
+ return output
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label_id = labels[xi]
+ v = torch.zeros((len(label_id), nc + 5), device=x.device)
+ v[:, :4] = label_id[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 5:].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(f"WARNING: NMS time limit {time_limit}s exceeded")
+ break # time limit exceeded
+
+ return output
+
+
+def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
+ coords[:, :10] /= gain
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
+ return coords
diff --git a/facelib/detection/yolov5face/utils/torch_utils.py b/facelib/detection/yolov5face/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2d06587b2d07b2eab199a8484380fde1de5c3c
--- /dev/null
+++ b/facelib/detection/yolov5face/utils/torch_utils.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+
+def fuse_conv_and_bn(conv, bn):
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ fusedconv = (
+ nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ groups=conv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(conv.weight.device)
+ )
+
+ # prepare filters
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+
+ # prepare spatial bias
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+ return fusedconv
+
+
+def copy_attr(a, b, include=(), exclude=()):
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
+ for k, v in b.__dict__.items():
+ if (include and k not in include) or k.startswith("_") or k in exclude:
+ continue
+
+ setattr(a, k, v)
diff --git a/facelib/parsing/__init__.py b/facelib/parsing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90cfad00267f028ebf595fd5339180fc6f6f6385
--- /dev/null
+++ b/facelib/parsing/__init__.py
@@ -0,0 +1,23 @@
+import torch
+
+from facelib.utils import load_file_from_url
+from .bisenet import BiSeNet
+from .parsenet import ParseNet
+
+
+def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
+ if model_name == 'bisenet':
+ model = BiSeNet(num_class=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
+ elif model_name == 'parsenet':
+ model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
diff --git a/facelib/parsing/bisenet.py b/facelib/parsing/bisenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3898cab76ae5876459cd4899c54cafa14234971d
--- /dev/null
+++ b/facelib/parsing/bisenet.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .resnet import ResNet18
+
+
+class ConvBNReLU(nn.Module):
+
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
+ super(ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_chan)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(self.bn(x))
+ return x
+
+
+class BiSeNetOutput(nn.Module):
+
+ def __init__(self, in_chan, mid_chan, num_class):
+ super(BiSeNetOutput, self).__init__()
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+ self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ feat = self.conv(x)
+ out = self.conv_out(feat)
+ return out, feat
+
+
+class AttentionRefinementModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(AttentionRefinementModule, self).__init__()
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+ self.bn_atten = nn.BatchNorm2d(out_chan)
+ self.sigmoid_atten = nn.Sigmoid()
+
+ def forward(self, x):
+ feat = self.conv(x)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv_atten(atten)
+ atten = self.bn_atten(atten)
+ atten = self.sigmoid_atten(atten)
+ out = torch.mul(feat, atten)
+ return out
+
+
+class ContextPath(nn.Module):
+
+ def __init__(self):
+ super(ContextPath, self).__init__()
+ self.resnet = ResNet18()
+ self.arm16 = AttentionRefinementModule(256, 128)
+ self.arm32 = AttentionRefinementModule(512, 128)
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+ def forward(self, x):
+ feat8, feat16, feat32 = self.resnet(x)
+ h8, w8 = feat8.size()[2:]
+ h16, w16 = feat16.size()[2:]
+ h32, w32 = feat32.size()[2:]
+
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
+ avg = self.conv_avg(avg)
+ avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
+
+ feat32_arm = self.arm32(feat32)
+ feat32_sum = feat32_arm + avg_up
+ feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
+ feat32_up = self.conv_head32(feat32_up)
+
+ feat16_arm = self.arm16(feat16)
+ feat16_sum = feat16_arm + feat32_up
+ feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
+ feat16_up = self.conv_head16(feat16_up)
+
+ return feat8, feat16_up, feat32_up # x8, x8, x16
+
+
+class FeatureFusionModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(FeatureFusionModule, self).__init__()
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, fsp, fcp):
+ fcat = torch.cat([fsp, fcp], dim=1)
+ feat = self.convblk(fcat)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv1(atten)
+ atten = self.relu(atten)
+ atten = self.conv2(atten)
+ atten = self.sigmoid(atten)
+ feat_atten = torch.mul(feat, atten)
+ feat_out = feat_atten + feat
+ return feat_out
+
+
+class BiSeNet(nn.Module):
+
+ def __init__(self, num_class):
+ super(BiSeNet, self).__init__()
+ self.cp = ContextPath()
+ self.ffm = FeatureFusionModule(256, 256)
+ self.conv_out = BiSeNetOutput(256, 256, num_class)
+ self.conv_out16 = BiSeNetOutput(128, 64, num_class)
+ self.conv_out32 = BiSeNetOutput(128, 64, num_class)
+
+ def forward(self, x, return_feat=False):
+ h, w = x.size()[2:]
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
+ feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+ out, feat = self.conv_out(feat_fuse)
+ out16, feat16 = self.conv_out16(feat_cp8)
+ out32, feat32 = self.conv_out32(feat_cp16)
+
+ out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
+ out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
+ out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
+
+ if return_feat:
+ feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
+ feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
+ feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
+ return out, out16, out32, feat, feat16, feat32
+ else:
+ return out, out16, out32
diff --git a/facelib/parsing/parsenet.py b/facelib/parsing/parsenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e178ebe43a1ef666aaea0bc0faf629485c22a24f
--- /dev/null
+++ b/facelib/parsing/parsenet.py
@@ -0,0 +1,194 @@
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+ """Normalization Layers.
+
+ Args:
+ channels: input channels, for batch norm and instance norm.
+ input_size: input shape without batch size, for layer norm.
+ """
+
+ def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+ super(NormLayer, self).__init__()
+ norm_type = norm_type.lower()
+ self.norm_type = norm_type
+ if norm_type == 'bn':
+ self.norm = nn.BatchNorm2d(channels, affine=True)
+ elif norm_type == 'in':
+ self.norm = nn.InstanceNorm2d(channels, affine=False)
+ elif norm_type == 'gn':
+ self.norm = nn.GroupNorm(32, channels, affine=True)
+ elif norm_type == 'pixel':
+ self.norm = lambda x: F.normalize(x, p=2, dim=1)
+ elif norm_type == 'layer':
+ self.norm = nn.LayerNorm(normalize_shape)
+ elif norm_type == 'none':
+ self.norm = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f'Norm type {norm_type} not support.'
+
+ def forward(self, x, ref=None):
+ if self.norm_type == 'spade':
+ return self.norm(x, ref)
+ else:
+ return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+ """Relu Layer.
+
+ Args:
+ relu type: type of relu layer, candidates are
+ - ReLU
+ - LeakyReLU: default relu slope 0.2
+ - PRelu
+ - SELU
+ - none: direct pass
+ """
+
+ def __init__(self, channels, relu_type='relu'):
+ super(ReluLayer, self).__init__()
+ relu_type = relu_type.lower()
+ if relu_type == 'relu':
+ self.func = nn.ReLU(True)
+ elif relu_type == 'leakyrelu':
+ self.func = nn.LeakyReLU(0.2, inplace=True)
+ elif relu_type == 'prelu':
+ self.func = nn.PReLU(channels)
+ elif relu_type == 'selu':
+ self.func = nn.SELU(True)
+ elif relu_type == 'none':
+ self.func = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f'Relu type {relu_type} not support.'
+
+ def forward(self, x):
+ return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ scale='none',
+ norm_type='none',
+ relu_type='none',
+ use_pad=True,
+ bias=True):
+ super(ConvLayer, self).__init__()
+ self.use_pad = use_pad
+ self.norm_type = norm_type
+ if norm_type in ['bn']:
+ bias = False
+
+ stride = 2 if scale == 'down' else 1
+
+ self.scale_func = lambda x: x
+ if scale == 'up':
+ self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+ self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+ self.relu = ReluLayer(out_channels, relu_type)
+ self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+ def forward(self, x):
+ out = self.scale_func(x)
+ if self.use_pad:
+ out = self.reflection_pad(out)
+ out = self.conv2d(out)
+ out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+
+class ResidualBlock(nn.Module):
+ """
+ Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+ """
+
+ def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+ super(ResidualBlock, self).__init__()
+
+ if scale == 'none' and c_in == c_out:
+ self.shortcut_func = lambda x: x
+ else:
+ self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+ scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+ scale_conf = scale_config_dict[scale]
+
+ self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+ self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+ def forward(self, x):
+ identity = self.shortcut_func(x)
+
+ res = self.conv1(x)
+ res = self.conv2(res)
+ return identity + res
+
+
+class ParseNet(nn.Module):
+
+ def __init__(self,
+ in_size=128,
+ out_size=128,
+ min_feat_size=32,
+ base_ch=64,
+ parsing_ch=19,
+ res_depth=10,
+ relu_type='LeakyReLU',
+ norm_type='bn',
+ ch_range=[32, 256]):
+ super().__init__()
+ self.res_depth = res_depth
+ act_args = {'norm_type': norm_type, 'relu_type': relu_type}
+ min_ch, max_ch = ch_range
+
+ ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
+ min_feat_size = min(in_size, min_feat_size)
+
+ down_steps = int(np.log2(in_size // min_feat_size))
+ up_steps = int(np.log2(out_size // min_feat_size))
+
+ # =============== define encoder-body-decoder ====================
+ self.encoder = []
+ self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+ head_ch = base_ch
+ for i in range(down_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+ self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
+ head_ch = head_ch * 2
+
+ self.body = []
+ for i in range(res_depth):
+ self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+ self.decoder = []
+ for i in range(up_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+ self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
+ head_ch = head_ch // 2
+
+ self.encoder = nn.Sequential(*self.encoder)
+ self.body = nn.Sequential(*self.body)
+ self.decoder = nn.Sequential(*self.decoder)
+ self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+ self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+ def forward(self, x):
+ feat = self.encoder(x)
+ x = feat + self.body(feat)
+ x = self.decoder(x)
+ out_img = self.out_img_conv(x)
+ out_mask = self.out_mask_conv(x)
+ return out_mask, out_img
diff --git a/facelib/parsing/resnet.py b/facelib/parsing/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec8e82cf64469fb51be21ad5130217052addbda
--- /dev/null
+++ b/facelib/parsing/resnet.py
@@ -0,0 +1,69 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(self, in_chan, out_chan, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
+ self.bn1 = nn.BatchNorm2d(out_chan)
+ self.conv2 = conv3x3(out_chan, out_chan)
+ self.bn2 = nn.BatchNorm2d(out_chan)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ if in_chan != out_chan or stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_chan),
+ )
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = F.relu(self.bn1(residual))
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = shortcut + residual
+ out = self.relu(out)
+ return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+ for i in range(bnum - 1):
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
+ return nn.Sequential(*layers)
+
+
+class ResNet18(nn.Module):
+
+ def __init__(self):
+ super(ResNet18, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(self.bn1(x))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ feat8 = self.layer2(x) # 1/8
+ feat16 = self.layer3(feat8) # 1/16
+ feat32 = self.layer4(feat16) # 1/32
+ return feat8, feat16, feat32
diff --git a/facelib/utils/__init__.py b/facelib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03b1c2bafcd7759cb7e8722a0c6715f201a46dc
--- /dev/null
+++ b/facelib/utils/__init__.py
@@ -0,0 +1,7 @@
+from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
+
+__all__ = [
+ 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
+ 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
+]
diff --git a/facelib/utils/face_restoration_helper.py b/facelib/utils/face_restoration_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3cf8456e59a79e4e82ef1b2325a2ff32a7604ce
--- /dev/null
+++ b/facelib/utils/face_restoration_helper.py
@@ -0,0 +1,705 @@
+import cv2
+import numpy as np
+import os
+import torch
+import pdb
+# import dlib
+from torchvision.transforms.functional import normalize
+
+from facelib.detection import init_detection_model
+from facelib.parsing import init_parsing_model
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import get_device
+
+dlib_model_url = {
+ 'face_detector': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
+ 'shape_predictor_5': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
+}
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array(
+ [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1]
+ >= 1), 'crop ration only supports >=1'
+ self.face_size = (
+ int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+ self.det_model = det_model
+
+ if self.det_model == 'dlib':
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861], [
+ 437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+ elif self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * \
+ (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * \
+ (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = get_device()
+ else:
+ self.device = device
+
+ # init face detection model
+ if self.det_model == 'dlib':
+ self.face_detector, self.shape_predictor_5 = self.init_dlib(
+ dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
+ else:
+ self.face_detector = init_detection_model(
+ det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(
+ model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+ self.is_gray = is_gray(img, threshold=10)
+ if self.is_gray:
+ print('Grayscale input: True')
+
+ if min(self.input_img.shape[:2]) < 512:
+ f = 512.0/min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(
+ self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def init_dlib(self, detection_path, landmark5_path):
+ """Initialize the dlib detectors and predictors."""
+ try:
+ import dlib
+ except ImportError:
+ print('Please install dlib by running:' 'conda install -c conda-forge dlib')
+ detection_path = load_file_from_url(
+ url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
+ landmark5_path = load_file_from_url(
+ url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+ return face_detector, shape_predictor_5
+
+ def get_face_landmarks_5_dlib(self,
+ only_keep_largest=False,
+ scale=1):
+ det_faces = self.face_detector(self.input_img, scale)
+
+ if len(det_faces) == 0:
+ print('No face detected. Try to increase upsample_num_times.')
+ return 0
+ else:
+ if only_keep_largest:
+ print('Detect several faces and only keep the largest.')
+ face_areas = []
+ for i in range(len(det_faces)):
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ self.det_faces = [det_faces[largest_idx]]
+ else:
+ self.det_faces = det_faces
+
+ if len(self.det_faces) == 0:
+ return 0
+
+ for face in self.det_faces:
+ shape = self.shape_predictor_5(self.input_img, face.rect)
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
+ self.all_landmarks_5.append(landmark)
+
+ return len(self.all_landmarks_5)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if self.det_model == 'dlib':
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
+
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(
+ self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_detector.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]]
+ for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]]
+ for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(
+ self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale,
+ np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(
+ self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * \
+ np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) -
+ pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ # pdb.set_trace()
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+ def add_restored_face(self, restored_face, input_face=None):
+ if self.is_gray:
+ # convert img into grayscale
+ restored_face = bgr2gray(restored_face)
+ if input_face is not None:
+ restored_face = adain_npy(
+ restored_face, input_face) # transfer the color
+ self.restored_faces.append(restored_face)
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(
+ self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(
+ upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(
+ restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (
+ self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(
+ restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400/np.sqrt(total_face_area))
+ mask_border[border:h-border, border:w-border, :] = 0
+ inv_mask_border = cv2.warpAffine(
+ mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones(
+ (erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(
+ inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(
+ restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype(
+ 'float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(
+ parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+ inv_soft_mask = inv_soft_parse_mask * \
+ fuse_mask + inv_soft_mask*(1-fuse_mask)
+
+ # alpha channel
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:
+ alpha = upsample_img[:, :, 3:]
+ upsample_img = inv_soft_mask * pasted_face + \
+ (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+ else:
+ upsample_img = inv_soft_mask * pasted_face + \
+ (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:, :, 0] = 0
+ img_color[:, :, 1] = 255
+ img_color[:, :, 2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + \
+ (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
+
+
+class FaceAligner(object):
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1]
+ >= 1), 'crop ration only supports >=1'
+ self.face_size = (
+ int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+ self.det_model = det_model
+
+ if self.det_model == 'dlib':
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861], [
+ 437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+ elif self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * \
+ (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * \
+ (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = get_device()
+ else:
+ self.device = device
+
+ def set_image(self, img):
+ self.input_img = img
+
+ def align_pair_face(self, img_lq, img_gt, landmarks):
+ img_lq = (img_lq[:, :, ::-1] * 255).round().astype(np.uint8)
+ img_gt = (img_gt[:, :, ::-1] * 255).round().astype(np.uint8)
+
+ self.set_image(img_gt)
+ img_lq, img_gt = self.align_warp_face(img_lq, img_gt, landmarks)
+ img_lq = img_lq[:, :, ::-1] / 255.0
+ img_gt = img_gt[:, :, ::-1] / 255.0
+ return img_lq, img_gt
+
+ def align_single_face(self, img, landmarks, border_mode='constant'):
+ """Align and warp faces with face template.
+ Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ """
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+
+ img = (img[:, :, ::-1] * 255).round().astype(np.uint8)
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks, self.face_template, method=cv2.LMEDS)[0]
+ img = cv2.warpAffine(
+ img, affine_matrix, img.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ img = img[:, :, ::-1] / 255.0
+ return img
+
+ def align_warp_face(self, img_lq, img_gt, landmarks, border_mode='constant'):
+ """Align and warp faces with face template.
+ Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ """
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ scale = img_gt.shape[0] / img_lq.shape[0]
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks, self.face_template, method=cv2.LMEDS)[0]
+ img_gt = cv2.warpAffine(
+ img_gt, affine_matrix, img_gt.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks / scale, self.face_template / scale, method=cv2.LMEDS)[0]
+ img_lq = cv2.warpAffine(
+ img_lq, affine_matrix, img_lq.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+
+ return img_lq, img_gt
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
diff --git a/facelib/utils/face_restoration_helper_bak.py b/facelib/utils/face_restoration_helper_bak.py
new file mode 100644
index 0000000000000000000000000000000000000000..e704c2c534ee39ecd00eecabbefadd4a90593535
--- /dev/null
+++ b/facelib/utils/face_restoration_helper_bak.py
@@ -0,0 +1,705 @@
+import cv2
+import numpy as np
+import os
+import torch
+import pdb
+import dlib
+from torchvision.transforms.functional import normalize
+
+from facelib.detection import init_detection_model
+from facelib.parsing import init_parsing_model
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import get_device
+
+dlib_model_url = {
+ 'face_detector': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
+ 'shape_predictor_5': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
+}
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array(
+ [(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1]
+ >= 1), 'crop ration only supports >=1'
+ self.face_size = (
+ int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+ self.det_model = det_model
+
+ if self.det_model == 'dlib':
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861], [
+ 437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+ elif self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * \
+ (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * \
+ (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = get_device()
+ else:
+ self.device = device
+
+ # init face detection model
+ if self.det_model == 'dlib':
+ self.face_detector, self.shape_predictor_5 = self.init_dlib(
+ dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
+ else:
+ self.face_detector = init_detection_model(
+ det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(
+ model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+ self.is_gray = is_gray(img, threshold=10)
+ if self.is_gray:
+ print('Grayscale input: True')
+
+ if min(self.input_img.shape[:2]) < 512:
+ f = 512.0/min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(
+ self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def init_dlib(self, detection_path, landmark5_path):
+ """Initialize the dlib detectors and predictors."""
+ try:
+ import dlib
+ except ImportError:
+ print('Please install dlib by running:' 'conda install -c conda-forge dlib')
+ detection_path = load_file_from_url(
+ url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
+ landmark5_path = load_file_from_url(
+ url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+ return face_detector, shape_predictor_5
+
+ def get_face_landmarks_5_dlib(self,
+ only_keep_largest=False,
+ scale=1):
+ det_faces = self.face_detector(self.input_img, scale)
+
+ if len(det_faces) == 0:
+ print('No face detected. Try to increase upsample_num_times.')
+ return 0
+ else:
+ if only_keep_largest:
+ print('Detect several faces and only keep the largest.')
+ face_areas = []
+ for i in range(len(det_faces)):
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ self.det_faces = [det_faces[largest_idx]]
+ else:
+ self.det_faces = det_faces
+
+ if len(self.det_faces) == 0:
+ return 0
+
+ for face in self.det_faces:
+ shape = self.shape_predictor_5(self.input_img, face.rect)
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
+ self.all_landmarks_5.append(landmark)
+
+ return len(self.all_landmarks_5)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if self.det_model == 'dlib':
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
+
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(
+ self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_detector.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]]
+ for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]]
+ for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(
+ self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale,
+ np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(
+ self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * \
+ np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) -
+ pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ # pdb.set_trace()
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+ def add_restored_face(self, restored_face, input_face=None):
+ if self.is_gray:
+ # convert img into grayscale
+ restored_face = bgr2gray(restored_face)
+ if input_face is not None:
+ restored_face = adain_npy(
+ restored_face, input_face) # transfer the color
+ self.restored_faces.append(restored_face)
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(
+ self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(
+ upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(
+ restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (
+ self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(
+ restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400/np.sqrt(total_face_area))
+ mask_border[border:h-border, border:w-border, :] = 0
+ inv_mask_border = cv2.warpAffine(
+ mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones(
+ (erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(
+ inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(
+ restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype(
+ 'float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255,
+ 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(
+ parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+ inv_soft_mask = inv_soft_parse_mask * \
+ fuse_mask + inv_soft_mask*(1-fuse_mask)
+
+ # alpha channel
+ if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:
+ alpha = upsample_img[:, :, 3:]
+ upsample_img = inv_soft_mask * pasted_face + \
+ (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+ upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+ else:
+ upsample_img = inv_soft_mask * pasted_face + \
+ (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:, :, 0] = 0
+ img_color[:, :, 1] = 255
+ img_color[:, :, 2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + \
+ (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
+
+
+class FaceAligner(object):
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1]
+ >= 1), 'crop ration only supports >=1'
+ self.face_size = (
+ int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+ self.det_model = det_model
+
+ if self.det_model == 'dlib':
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+ [337.91089109, 488.38613861], [
+ 437.95049505, 493.51485149],
+ [513.58415842, 678.5049505]])
+ self.face_template = self.face_template / (1024 // face_size)
+ elif self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * \
+ (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * \
+ (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.device = get_device()
+ else:
+ self.device = device
+
+ def set_image(self, img):
+ self.input_img = img
+
+ def align_pair_face(self, img_lq, img_gt, landmarks):
+ img_lq = (img_lq[:, :, ::-1] * 255).round().astype(np.uint8)
+ img_gt = (img_gt[:, :, ::-1] * 255).round().astype(np.uint8)
+
+ self.set_image(img_gt)
+ img_lq, img_gt = self.align_warp_face(img_lq, img_gt, landmarks)
+ img_lq = img_lq[:, :, ::-1] / 255.0
+ img_gt = img_gt[:, :, ::-1] / 255.0
+ return img_lq, img_gt
+
+ def align_single_face(self, img, landmarks, border_mode='constant'):
+ """Align and warp faces with face template.
+ Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ """
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+
+ img = (img[:, :, ::-1] * 255).round().astype(np.uint8)
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks, self.face_template, method=cv2.LMEDS)[0]
+ img = cv2.warpAffine(
+ img, affine_matrix, img.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ img = img[:, :, ::-1] / 255.0
+ return img
+
+ def align_warp_face(self, img_lq, img_gt, landmarks, border_mode='constant'):
+ """Align and warp faces with face template.
+ Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ """
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ scale = img_gt.shape[0] / img_lq.shape[0]
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks, self.face_template, method=cv2.LMEDS)[0]
+ img_gt = cv2.warpAffine(
+ img_gt, affine_matrix, img_gt.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+
+ affine_matrix = cv2.estimateAffinePartial2D(
+ landmarks / scale, self.face_template / scale, method=cv2.LMEDS)[0]
+ img_lq = cv2.warpAffine(
+ img_lq, affine_matrix, img_lq.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+
+ return img_lq, img_gt
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
diff --git a/facelib/utils/face_utils.py b/facelib/utils/face_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1474a2a4419b6b62fab8a919ef805b802556464
--- /dev/null
+++ b/facelib/utils/face_utils.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ if preserve_aspect:
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+ else:
+ width_increase = height_increase = increase_area
+ left = int(left - width_increase * width)
+ top = int(top - height_increase * height)
+ right = int(right + width_increase * width)
+ bot = int(bot + height_increase * height)
+ return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+ left = max(bboxes[0], 0)
+ top = max(bboxes[1], 0)
+ right = min(bboxes[2], w)
+ bottom = min(bboxes[3], h)
+ return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+ landmarks,
+ output_size,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=False,
+ shrink_ratio=(1, 1)):
+ """Align and crop face with landmarks.
+
+ The output_size and transform_size are based on width. The height is
+ adjusted based on shrink_ratio_h/shring_ration_w.
+
+ Modified from:
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+ Args:
+ img (Numpy array): Input image.
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
+ output_size (int): Output face size.
+ transform_size (ing): Transform size. Usually the four time of
+ output_size.
+ enable_padding (float): Default: True.
+ shrink_ratio (float | tuple[float] | list[float]): Shring the whole
+ face for height and width (crop larger area). Default: (1, 1).
+
+ Returns:
+ (Numpy array): Cropped face.
+ """
+ lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
+
+ if isinstance(shrink_ratio, (float, int)):
+ shrink_ratio = (shrink_ratio, shrink_ratio)
+ if transform_size is None:
+ transform_size = output_size * 4
+
+ # Parse landmarks
+ lm = np.array(landmarks)
+ if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+ eye_left = lm[0]
+ eye_right = lm[1]
+ mouth_avg = (lm[3] + lm[4]) * 0.5
+ elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+ lm_eye_left = lm[2:4]
+ lm_eye_right = lm[0:2]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = lm[4]
+ elif lm.shape[0] == 68:
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[48] + lm[54]) * 0.5
+ elif lm.shape[0] == 98:
+ lm_eye_left = lm[60:68]
+ lm_eye_right = lm[68:76]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[76] + lm[82]) * 0.5
+
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1 # TODO: you can edit it to get larger rect
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ x *= shrink_ratio[1] # width
+ y *= shrink_ratio[0] # height
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+
+ quad_ori = np.copy(quad)
+ # Shrink, for large face
+ # TODO: do we really need shrink
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ h, w = img.shape[0:2]
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop
+ h, w = img.shape[0:2]
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+ img = img[crop[1]:crop[3], crop[0]:crop[2], :]
+ quad -= crop[0:2]
+
+ # Pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ h, w = img.shape[0:2]
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w = img.shape[0:2]
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * 0.02)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+ img = img.astype('float32')
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = np.clip(img, 0, 255) # float32, [0, 255]
+ quad += pad[:2]
+
+ # Transform use cv2
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+ cropped_face = cv2.warpAffine(
+ img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
+
+ if output_size < transform_size:
+ cropped_face = cv2.resize(
+ cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
+
+ if return_inverse_affine:
+ dst_h, dst_w = int(output_size * h_ratio), output_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ else:
+ inverse_affine = None
+ return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+ h, w = img.shape[0:2]
+ face_h, face_w = face.shape[0:2]
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+ # float32, [0, 255]
+ return img
+
+
+if __name__ == '__main__':
+ import os
+
+ from facelib.detection import init_detection_model
+ from facelib.utils.face_restoration_helper import get_largest_face
+
+ img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+ img_name = os.splitext(os.path.basename(img_path))[0]
+
+ # initialize model
+ det_net = init_detection_model('retinaface_resnet50', half=False)
+ img_ori = cv2.imread(img_path)
+ h, w = img_ori.shape[0:2]
+ # if larger than 800, scale it
+ scale = max(h / 800, w / 800)
+ if scale > 1:
+ img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+
+ with torch.no_grad():
+ bboxes = det_net.detect_faces(img, 0.97)
+ if scale > 1:
+ bboxes *= scale # the score is incorrect
+ bboxes = get_largest_face(bboxes, h, w)[0]
+
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+ cropped_face, inverse_affine = align_crop_face_landmarks(
+ img_ori,
+ landmarks,
+ output_size=512,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=True,
+ shrink_ratio=(1, 1))
+
+ cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face)
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
+ cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/facelib/utils/misc.py b/facelib/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..1875579294e15e20abccd9d719a7ca4fe6d43cd4
--- /dev/null
+++ b/facelib/utils/misc.py
@@ -0,0 +1,202 @@
+import cv2
+import os
+import os.path as osp
+import numpy as np
+from PIL import Image
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from basicsr.utils.download_util import download_file_from_google_drive
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+ import gdown
+
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+ user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
+ if user_response.lower() == 'y':
+ print(f'Covering {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == 'n':
+ print(f'Skipping {file_name}')
+ else:
+ raise ValueError('Wrong input. Only accepts Y/N.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ """
+ if model_dir is None:
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def is_gray(img, threshold=10):
+ img = Image.fromarray(img)
+ if len(img.getbands()) == 1:
+ return True
+ img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
+ img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
+ img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
+ diff1 = (img1 - img2).var()
+ diff2 = (img2 - img3).var()
+ diff3 = (img3 - img1).var()
+ diff_sum = (diff1 + diff2 + diff3) / 3.0
+ if diff_sum <= threshold:
+ return True
+ else:
+ return False
+
+def rgb2gray(img, out_channel=3):
+ r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+ if out_channel == 3:
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+ return gray
+
+def bgr2gray(img, out_channel=3):
+ b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+ if out_channel == 3:
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+ return gray
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """
+ Args:
+ feat (numpy): 3D [w h c]s
+ """
+ size = feat.shape
+ assert len(size) == 3, 'The input feature should be 3D tensor.'
+ c = size[2]
+ feat_var = feat.reshape(-1, c).var(axis=0) + eps
+ feat_std = np.sqrt(feat_var).reshape(1, 1, c)
+ feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
+ return feat_mean, feat_std
+
+
+def adain_npy(content_feat, style_feat):
+ """Adaptive instance normalization for numpy.
+
+ Args:
+ content_feat (numpy): The input feature.
+ style_feat (numpy): The reference feature.
+ """
+ size = content_feat.shape
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
+ return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
\ No newline at end of file
diff --git a/hugging_face/app.py b/hugging_face/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22994fc3443b508e68d12cde22786247e886f5d
--- /dev/null
+++ b/hugging_face/app.py
@@ -0,0 +1,266 @@
+import os
+import cv2
+import argparse
+import glob
+import torch
+import numpy as np
+from tqdm import tqdm
+from torchvision.transforms.functional import normalize
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import gpu_is_available, get_device
+from scipy.ndimage import gaussian_filter1d
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
+from basicsr.utils.video_util import VideoReader, VideoWriter
+from basicsr.utils.registry import ARCH_REGISTRY
+import gradio as gr
+from torch.hub import download_url_to_file
+
+title = r"""KEEP: Kalman-Inspired Feature Propagation for Video Face Super-Resolution
"""
+
+description = r"""
+Official Gradio demo for Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution (ECCV 2024).
+🔥 KEEP is a robust video face super-resolution algorithm.
+🤗 Try to drop your own face video, and get the restored results!
+"""
+
+post_article = r"""
+If you found KEEP helpful, please consider ⭐ the Github Repo. Thanks!
+[](https://github.com/jnjaby/KEEP)
+---
+📝 **Citation**
+
+If our work is useful for your research, please consider citing:
+```bibtex
+@InProceedings{feng2024keep,
+ title = {Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution},
+ author = {Feng, Ruicheng and Li, Chongyi and Loy, Chen Change},
+ booktitle = {European Conference on Computer Vision (ECCV)},
+ year = {2024}
+}
+```
+
+📋 **License**
+
+This project is licensed under S-Lab License 1.0.
+Redistribution and use for non-commercial purposes should follow this license.
+
+📧 **Contact**
+
+If you have any questions, please feel free to reach out via ruicheng002@ntu.edu.sg.
+"""
+
+
+
+def interpolate_sequence(sequence):
+ interpolated_sequence = np.copy(sequence)
+ missing_indices = np.isnan(sequence)
+ if np.any(missing_indices):
+ valid_indices = ~missing_indices
+ x = np.arange(len(sequence))
+ interpolated_sequence[missing_indices] = np.interp(x[missing_indices], x[valid_indices], sequence[valid_indices])
+ return interpolated_sequence
+
+def set_realesrgan():
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.realesrgan_utils import RealESRGANer
+ use_half = False
+ if torch.cuda.is_available():
+ no_half_gpu_list = ['1650', '1660']
+ if not any(gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list):
+ use_half = True
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ upsampler = RealESRGANer(scale=2, model_path="https://github.com/jnjaby/KEEP/releases/download/v0.1.0/RealESRGAN_x2plus.pth", model=model, tile=400, tile_pad=40, pre_pad=0, half=use_half)
+ if not gpu_is_available():
+ import warnings
+ warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA. The unoptimized RealESRGAN is slow on CPU.', category=RuntimeWarning)
+ return upsampler
+
+def process_video(input_video, draw_box, bg_enhancement):
+ device = get_device()
+ args = argparse.Namespace(
+ input_path=input_video,
+ upscale=1,
+ max_length=20,
+ has_aligned=False,
+ only_center_face=True,
+ draw_box=draw_box,
+ detection_model='retinaface_resnet50',
+ bg_enhancement=bg_enhancement,
+ face_upsample=False,
+ bg_tile=400,
+ suffix=None,
+ save_video_fps=None,
+ model_type='KEEP'
+ )
+
+ output_dir = './results/'
+ os.makedirs(output_dir, exist_ok=True)
+
+ model_configs = {
+ 'KEEP': {
+ 'architecture': {
+ 'img_size': 512, 'emb_dim': 256, 'dim_embd': 512, 'n_head': 8, 'n_layers': 9,
+ 'codebook_size': 1024, 'cft_list': ['16', '32', '64'], 'kalman_attn_head_dim': 48,
+ 'num_uncertainty_layers': 3, 'cfa_list': ['16', '32'], 'cfa_nhead': 4, 'cfa_dim': 256, 'cond': 1
+ },
+ 'checkpoint_dir': '../weights/KEEP',
+ 'checkpoint_url': 'https://github.com/jnjaby/KEEP/releases/download/v1.0.0/KEEP-b76feb75.pth'
+ },
+ }
+ if args.bg_enhancement:
+ bg_upsampler = set_realesrgan()
+ else:
+ bg_upsampler = None
+ if args.face_upsample:
+ face_upsampler = bg_upsampler if bg_upsampler is not None else set_realesrgan()
+ else:
+ face_upsampler = None
+
+ if args.model_type not in model_configs:
+ raise ValueError(f"Unknown model type: {args.model_type}. Available options: {list(model_configs.keys())}")
+ config = model_configs[args.model_type]
+ net = ARCH_REGISTRY.get('KEEP')(**config['architecture']).to(device)
+ ckpt_path = load_file_from_url(url=config['checkpoint_url'], model_dir=config['checkpoint_dir'], progress=True, file_name=None)
+ checkpoint = torch.load(ckpt_path, weights_only=True)
+ net.load_state_dict(checkpoint['params_ema'])
+ net.eval()
+ if not args.has_aligned:
+ print(f'Face detection model: {args.detection_model}')
+ if bg_upsampler is not None:
+ print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
+ else:
+ print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
+ face_helper = FaceRestoreHelper(args.upscale, face_size=512, crop_ratio=(1, 1), det_model=args.detection_model, save_ext='png', use_parse=True, device=device)
+
+ # Reading the input video.
+ input_img_list = []
+ if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):
+ vidreader = VideoReader(args.input_path)
+ image = vidreader.get_frame()
+ while image is not None:
+ input_img_list.append(image)
+ image = vidreader.get_frame()
+ fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
+ vidreader.close()
+ clip_name = os.path.basename(args.input_path)[:-4]
+ else:
+ raise TypeError(f'Unrecognized type of input video {args.input_path}.')
+ if len(input_img_list) == 0:
+ raise FileNotFoundError('No input image/video is found...')
+
+ print('Detecting keypoints and smooth alignment ...')
+ if not args.has_aligned:
+ raw_landmarks = []
+ for i, img in enumerate(input_img_list):
+ face_helper.clean_all()
+ face_helper.read_image(img)
+ num_det_faces = face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5, only_keep_largest=True)
+ if num_det_faces == 1:
+ raw_landmarks.append(face_helper.all_landmarks_5[0].reshape((10,)))
+ elif num_det_faces == 0:
+ raw_landmarks.append(np.array([np.nan]*10))
+ raw_landmarks = np.array(raw_landmarks)
+ for i in range(10):
+ raw_landmarks[:, i] = interpolate_sequence(raw_landmarks[:, i])
+ video_length = len(input_img_list)
+ avg_landmarks = gaussian_filter1d(raw_landmarks, 5, axis=0).reshape(video_length, 5, 2)
+ cropped_faces = []
+ for i, img in enumerate(input_img_list):
+ face_helper.clean_all()
+ face_helper.read_image(img)
+ face_helper.all_landmarks_5 = [avg_landmarks[i]]
+ face_helper.align_warp_face()
+ cropped_face_t = img2tensor(face_helper.cropped_faces[0] / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_faces.append(cropped_face_t)
+ cropped_faces = torch.stack(cropped_faces, dim=0).unsqueeze(0).to(device)
+ print('Restoring faces ...')
+ with torch.no_grad():
+ video_length = cropped_faces.shape[1]
+ output = []
+ for start_idx in range(0, video_length, args.max_length):
+ end_idx = min(start_idx + args.max_length, video_length)
+ if end_idx - start_idx == 1:
+ output.append(net(cropped_faces[:, [start_idx, start_idx], ...], need_upscale=False)[:, 0:1, ...])
+ else:
+ output.append(net(cropped_faces[:, start_idx:end_idx, ...], need_upscale=False))
+ output = torch.cat(output, dim=1).squeeze(0)
+ assert output.shape[0] == video_length, "Different number of frames"
+ restored_faces = [tensor2img(x, rgb2bgr=True, min_max=(-1, 1)) for x in output]
+ del output
+ torch.cuda.empty_cache()
+ print('Pasting faces back ...')
+
+ restored_frames = []
+ for i, img in enumerate(input_img_list):
+ face_helper.clean_all()
+ if args.has_aligned:
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.is_gray = is_gray(img, threshold=10)
+ if face_helper.is_gray:
+ print('Grayscale input: True')
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ face_helper.all_landmarks_5 = [avg_landmarks[i]]
+ face_helper.align_warp_face()
+ face_helper.add_restored_face(restored_faces[i].astype('uint8'))
+ if not args.has_aligned:
+ if bg_upsampler is not None:
+ bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
+ else:
+ bg_img = None
+ face_helper.get_inverse_affine(None)
+ if args.face_upsample and face_upsampler is not None:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
+ else:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
+
+ restored_frames.append(restored_img)
+
+ # Saving the output video.
+ print('Saving video ...')
+ height, width = restored_frames[0].shape[:2]
+ save_restore_path = os.path.join(output_dir, f'{clip_name}.mp4')
+ vidwriter = VideoWriter(save_restore_path, height, width, fps)
+ for f in restored_frames:
+ vidwriter.write_frame(f)
+ vidwriter.close()
+ print(f'All results are saved in {save_restore_path}.')
+ return save_restore_path
+
+# Downloading necessary models and sample videos.
+sample_videos_dir = os.path.join("test_sample/")
+os.makedirs(sample_videos_dir, exist_ok=True)
+download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_1.mp4", os.path.join(sample_videos_dir, "real_1.mp4"))
+download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_2.mp4", os.path.join(sample_videos_dir, "real_2.mp4"))
+download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_3.mp4", os.path.join(sample_videos_dir, "real_3.mp4"))
+download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_4.mp4", os.path.join(sample_videos_dir, "real_4.mp4"))
+
+model_dir = os.path.join("../weights/KEEP")
+_ = load_file_from_url(url='https://github.com/jnjaby/KEEP/releases/download/v1.0.0/KEEP-b76feb75.pth', model_dir=model_dir, progress=True, file_name=None)
+
+# Launching the Gradio interface.
+demo = gr.Interface(
+ fn=process_video,
+ title=title,
+ description=description,
+ inputs=[
+ gr.Video(label="Input Video"),
+ gr.Checkbox(label="Draw Box", value=False),
+ gr.Checkbox(label="Background Enhancement", value=False),
+ ],
+ outputs=gr.Video(label="Processed Video"),
+ examples=[
+ [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_1.mp4"), True, False],
+ [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_2.mp4"), True, False],
+ [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_3.mp4"), True, False],
+ [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_4.mp4"), True, False],
+ ],
+ article=post_article
+)
+
+
+demo.launch(share=True)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b517e8eddfd684c4da5b790e7277db9b9ac983bf
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+addict
+future
+lmdb
+numpy
+opencv-python
+Pillow
+pyyaml
+requests
+scikit-image
+scipy
+tb-nightly
+torch>=1.7.1
+torchvision
+tqdm
+yapf
+lpips
+gdown # supports downloading the large file from Google Drive
+diffusers==0.11.0
+einops
+huggingface_hub==0.20.2
+ffmpeg-python==0.2.0
+av
+cmake # for dlib
+dlib