sczhou committed
Commit 4f25e99
1 Parent(s): 0334511

add codeformer code.

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. CodeFormer +0 -1
  2. CodeFormer/.gitignore +129 -0
  3. CodeFormer/README.md +123 -0
  4. CodeFormer/assets/CodeFormer_logo.png +0 -0
  5. CodeFormer/assets/color_enhancement_result1.png +0 -0
  6. CodeFormer/assets/color_enhancement_result2.png +0 -0
  7. CodeFormer/assets/inpainting_result1.png +0 -0
  8. CodeFormer/assets/inpainting_result2.png +0 -0
  9. CodeFormer/assets/network.jpg +0 -0
  10. CodeFormer/assets/restoration_result1.png +0 -0
  11. CodeFormer/assets/restoration_result2.png +0 -0
  12. CodeFormer/assets/restoration_result3.png +0 -0
  13. CodeFormer/assets/restoration_result4.png +0 -0
  14. CodeFormer/basicsr/VERSION +1 -0
  15. CodeFormer/basicsr/__init__.py +11 -0
  16. CodeFormer/basicsr/archs/__init__.py +25 -0
  17. CodeFormer/basicsr/archs/arcface_arch.py +245 -0
  18. CodeFormer/basicsr/archs/arch_util.py +318 -0
  19. CodeFormer/basicsr/archs/codeformer_arch.py +276 -0
  20. CodeFormer/basicsr/archs/rrdbnet_arch.py +119 -0
  21. CodeFormer/basicsr/archs/vgg_arch.py +161 -0
  22. CodeFormer/basicsr/archs/vqgan_arch.py +435 -0
  23. CodeFormer/basicsr/data/__init__.py +100 -0
  24. CodeFormer/basicsr/data/data_sampler.py +48 -0
  25. CodeFormer/basicsr/data/data_util.py +305 -0
  26. CodeFormer/basicsr/data/prefetch_dataloader.py +125 -0
  27. CodeFormer/basicsr/data/transforms.py +165 -0
  28. CodeFormer/basicsr/losses/__init__.py +26 -0
  29. CodeFormer/basicsr/losses/loss_util.py +95 -0
  30. CodeFormer/basicsr/losses/losses.py +455 -0
  31. CodeFormer/basicsr/metrics/__init__.py +19 -0
  32. CodeFormer/basicsr/metrics/metric_util.py +45 -0
  33. CodeFormer/basicsr/metrics/psnr_ssim.py +128 -0
  34. CodeFormer/basicsr/models/__init__.py +30 -0
  35. CodeFormer/basicsr/ops/__init__.py +0 -0
  36. CodeFormer/basicsr/ops/dcn/__init__.py +7 -0
  37. CodeFormer/basicsr/ops/dcn/deform_conv.py +377 -0
  38. CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp +685 -0
  39. CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu +867 -0
  40. CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp +164 -0
  41. CodeFormer/basicsr/ops/fused_act/__init__.py +3 -0
  42. CodeFormer/basicsr/ops/fused_act/fused_act.py +89 -0
  43. CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp +26 -0
  44. CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu +100 -0
  45. CodeFormer/basicsr/ops/upfirdn2d/__init__.py +3 -0
  46. CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp +24 -0
  47. CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu +370 -0
  48. CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py +186 -0
  49. CodeFormer/basicsr/setup.py +165 -0
  50. CodeFormer/basicsr/train.py +225 -0
CodeFormer DELETED
@@ -1 +0,0 @@
- Subproject commit c5b4593074ba6214284d6acd5f1719b6c5d739af
CodeFormer/.gitignore ADDED
@@ -0,0 +1,129 @@
1
+ .vscode
2
+
3
+ # ignored files
4
+ version.py
5
+
6
+ # ignored files with suffix
7
+ *.html
8
+ # *.png
9
+ # *.jpeg
10
+ # *.jpg
11
+ *.pt
12
+ *.gif
13
+ *.pth
14
+ *.dat
15
+ *.zip
16
+
17
+ # template
18
+
19
+ # Byte-compiled / optimized / DLL files
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+
24
+ # C extensions
25
+ *.so
26
+
27
+ # Distribution / packaging
28
+ .Python
29
+ build/
30
+ develop-eggs/
31
+ dist/
32
+ downloads/
33
+ eggs/
34
+ .eggs/
35
+ lib/
36
+ lib64/
37
+ parts/
38
+ sdist/
39
+ var/
40
+ wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .coverage
60
+ .coverage.*
61
+ .cache
62
+ nosetests.xml
63
+ coverage.xml
64
+ *.cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+
77
+ # Flask stuff:
78
+ instance/
79
+ .webassets-cache
80
+
81
+ # Scrapy stuff:
82
+ .scrapy
83
+
84
+ # Sphinx documentation
85
+ docs/_build/
86
+
87
+ # PyBuilder
88
+ target/
89
+
90
+ # Jupyter Notebook
91
+ .ipynb_checkpoints
92
+
93
+ # pyenv
94
+ .python-version
95
+
96
+ # celery beat schedule file
97
+ celerybeat-schedule
98
+
99
+ # SageMath parsed files
100
+ *.sage.py
101
+
102
+ # Environments
103
+ .env
104
+ .venv
105
+ env/
106
+ venv/
107
+ ENV/
108
+ env.bak/
109
+ venv.bak/
110
+
111
+ # Spyder project settings
112
+ .spyderproject
113
+ .spyproject
114
+
115
+ # Rope project settings
116
+ .ropeproject
117
+
118
+ # mkdocs documentation
119
+ /site
120
+
121
+ # mypy
122
+ .mypy_cache/
123
+
124
+ # project
125
+ results/
126
+ dlib/
127
+ *.pth
128
+ *_old*
129
+
CodeFormer/README.md ADDED
@@ -0,0 +1,123 @@
+ <p align="center">
+   <img src="assets/CodeFormer_logo.png" height=110>
+ </p>
+
+ ## Towards Robust Blind Face Restoration with Codebook Lookup Transformer
+
+ [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+
+
+ <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
+
+ [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+ S-Lab, Nanyang Technological University
+
+ <img src="assets/network.jpg" width="800px"/>
+
+
+ :star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+ ### Update
+
+ - **2022.09.09**: Integrated into :rocket: [Replicate](https://replicate.com/). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
+ - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
+ - **2022.08.23**: Some modifications to face detection and fusion for better AI-created face enhancement.
+ - **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
+ - **2022.07.29**: Integrate new face detectors: `['RetinaFace' (default), 'YOLOv5']`.
+ - **2022.07.17**: Add Colab demo of CodeFormer. <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
+ - **2022.07.16**: Release inference code for face restoration. :blush:
+ - **2022.06.21**: This repo is created.
+
+ ### TODO
+ - [ ] Add checkpoint for face inpainting
+ - [ ] Add training code and config files
+ - [x] ~~Add background image enhancement~~
+
+ #### Face Restoration
+
+ <img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
+ <img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>
+
+ #### Face Color Enhancement and Restoration
+
+ <img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>
+
+ #### Face Inpainting
+
+ <img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>
+
+
+ ### Dependencies and Installation
+
+ - Pytorch >= 1.7.1
+ - CUDA >= 10.1
+ - Other required packages in `requirements.txt`
+ ```
+ # git clone this repository
+ git clone https://github.com/sczhou/CodeFormer
+ cd CodeFormer
+
+ # create new anaconda env
+ conda create -n codeformer python=3.8 -y
+ conda activate codeformer
+
+ # install python dependencies
+ pip3 install -r requirements.txt
+ python basicsr/setup.py develop
+ ```
+ <!-- conda install -c conda-forge dlib -->
+
+ ### Quick Inference
+
+ ##### Download Pre-trained Models:
+ Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can download the pretrained models manually OR by running the following command.
+ ```
+ python scripts/download_pretrained_models.py facelib
+ ```
+
+ Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can download the pretrained models manually OR by running the following command.
+ ```
+ python scripts/download_pretrained_models.py CodeFormer
+ ```
+
+ ##### Prepare Testing Data:
+ You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder.
+
+
+ ##### Testing on Face Restoration:
+ ```
+ # For cropped and aligned faces
+ python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
+
+ # For whole images
+ # Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+ # Add '--face_upsample' to further upsample the restored faces with Real-ESRGAN
+ python inference_codeformer.py --w 0.7 --test_path [input folder]
+ ```
+
+ NOTE that *w* is in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result.
+
+ The results will be saved in the `results` folder.
+
+ ### Citation
+ If our work is useful for your research, please consider citing:
+
+     @article{zhou2022codeformer,
+         author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+         title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+         journal = {arXiv preprint arXiv:2206.11253},
+         year = {2022}
+     }
+
+ ### License
+
+ <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
+
+ ### Acknowledgement
+
+ This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). We also borrow some code from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). Thanks for their awesome work.
+
+ ### Contact
+ If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
CodeFormer/assets/CodeFormer_logo.png ADDED
CodeFormer/assets/color_enhancement_result1.png ADDED
CodeFormer/assets/color_enhancement_result2.png ADDED
CodeFormer/assets/inpainting_result1.png ADDED
CodeFormer/assets/inpainting_result2.png ADDED
CodeFormer/assets/network.jpg ADDED
CodeFormer/assets/restoration_result1.png ADDED
CodeFormer/assets/restoration_result2.png ADDED
CodeFormer/assets/restoration_result3.png ADDED
CodeFormer/assets/restoration_result4.png ADDED
CodeFormer/basicsr/VERSION ADDED
@@ -0,0 +1 @@
+ 1.3.2
CodeFormer/basicsr/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # https://github.com/xinntao/BasicSR
+ # flake8: noqa
+ from .archs import *
+ from .data import *
+ from .losses import *
+ from .metrics import *
+ from .models import *
+ from .ops import *
+ from .train import *
+ from .utils import *
+ from .version import __gitsha__, __version__
CodeFormer/basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
+ import importlib
+ from copy import deepcopy
+ from os import path as osp
+
+ from basicsr.utils import get_root_logger, scandir
+ from basicsr.utils.registry import ARCH_REGISTRY
+
+ __all__ = ['build_network']
+
+ # automatically scan and import arch modules for registry
+ # scan all the files under the 'archs' folder and collect files ending with
+ # '_arch.py'
+ arch_folder = osp.dirname(osp.abspath(__file__))
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+ # import all the arch modules
+ _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+ def build_network(opt):
+     opt = deepcopy(opt)
+     network_type = opt.pop('type')
+     net = ARCH_REGISTRY.get(network_type)(**opt)
+     logger = get_root_logger()
+     logger.info(f'Network [{net.__class__.__name__}] is created.')
+     return net
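For orientation, here is a minimal usage sketch (not part of this commit) of `build_network`: any class decorated with `@ARCH_REGISTRY.register()` can be constructed from a plain options dict whose `type` key names the class. The keyword values below are the `CodeFormer` constructor defaults from `codeformer_arch.py`, not values taken from a shipped config file.

```python
# Illustrative sketch: build a registered architecture from an options dict.
# 'type' selects the class from ARCH_REGISTRY; the remaining keys are passed
# through as constructor kwargs (the dict is deep-copied, so it is not mutated).
from basicsr.archs import build_network

opt = {
    'type': 'CodeFormer',
    'dim_embd': 512,
    'n_head': 8,
    'n_layers': 9,
    'codebook_size': 1024,
    'connect_list': ['32', '64', '128', '256'],
}
net = build_network(opt)
```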
CodeFormer/basicsr/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
1
+ import torch.nn as nn
2
+ from basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
+ def conv3x3(inplanes, outplanes, stride=1):
6
+ """A simple wrapper for 3x3 convolution with padding.
7
+
8
+ Args:
9
+ inplanes (int): Channel number of inputs.
10
+ outplanes (int): Channel number of outputs.
11
+ stride (int): Stride in convolution. Default: 1.
12
+ """
13
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
14
+
15
+
16
+ class BasicBlock(nn.Module):
17
+ """Basic residual block used in the ResNetArcFace architecture.
18
+
19
+ Args:
20
+ inplanes (int): Channel number of inputs.
21
+ planes (int): Channel number of outputs.
22
+ stride (int): Stride in convolution. Default: 1.
23
+ downsample (nn.Module): The downsample module. Default: None.
24
+ """
25
+ expansion = 1 # output channel expansion ratio
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = nn.BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = nn.BatchNorm2d(planes)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class IRBlock(nn.Module):
57
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
58
+
59
+ Args:
60
+ inplanes (int): Channel number of inputs.
61
+ planes (int): Channel number of outputs.
62
+ stride (int): Stride in convolution. Default: 1.
63
+ downsample (nn.Module): The downsample module. Default: None.
64
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
65
+ """
66
+ expansion = 1 # output channel expansion ratio
67
+
68
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
69
+ super(IRBlock, self).__init__()
70
+ self.bn0 = nn.BatchNorm2d(inplanes)
71
+ self.conv1 = conv3x3(inplanes, inplanes)
72
+ self.bn1 = nn.BatchNorm2d(inplanes)
73
+ self.prelu = nn.PReLU()
74
+ self.conv2 = conv3x3(inplanes, planes, stride)
75
+ self.bn2 = nn.BatchNorm2d(planes)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+ self.use_se = use_se
79
+ if self.use_se:
80
+ self.se = SEBlock(planes)
81
+
82
+ def forward(self, x):
83
+ residual = x
84
+ out = self.bn0(x)
85
+ out = self.conv1(out)
86
+ out = self.bn1(out)
87
+ out = self.prelu(out)
88
+
89
+ out = self.conv2(out)
90
+ out = self.bn2(out)
91
+ if self.use_se:
92
+ out = self.se(out)
93
+
94
+ if self.downsample is not None:
95
+ residual = self.downsample(x)
96
+
97
+ out += residual
98
+ out = self.prelu(out)
99
+
100
+ return out
101
+
102
+
103
+ class Bottleneck(nn.Module):
104
+ """Bottleneck block used in the ResNetArcFace architecture.
105
+
106
+ Args:
107
+ inplanes (int): Channel number of inputs.
108
+ planes (int): Channel number of outputs.
109
+ stride (int): Stride in convolution. Default: 1.
110
+ downsample (nn.Module): The downsample module. Default: None.
111
+ """
112
+ expansion = 4 # output channel expansion ratio
113
+
114
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
115
+ super(Bottleneck, self).__init__()
116
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
117
+ self.bn1 = nn.BatchNorm2d(planes)
118
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
119
+ self.bn2 = nn.BatchNorm2d(planes)
120
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
121
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
122
+ self.relu = nn.ReLU(inplace=True)
123
+ self.downsample = downsample
124
+ self.stride = stride
125
+
126
+ def forward(self, x):
127
+ residual = x
128
+
129
+ out = self.conv1(x)
130
+ out = self.bn1(out)
131
+ out = self.relu(out)
132
+
133
+ out = self.conv2(out)
134
+ out = self.bn2(out)
135
+ out = self.relu(out)
136
+
137
+ out = self.conv3(out)
138
+ out = self.bn3(out)
139
+
140
+ if self.downsample is not None:
141
+ residual = self.downsample(x)
142
+
143
+ out += residual
144
+ out = self.relu(out)
145
+
146
+ return out
147
+
148
+
149
+ class SEBlock(nn.Module):
150
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
151
+
152
+ Args:
153
+ channel (int): Channel number of inputs.
154
+ reduction (int): Channel reduction ratio. Default: 16.
155
+ """
156
+
157
+ def __init__(self, channel, reduction=16):
158
+ super(SEBlock, self).__init__()
159
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
160
+ self.fc = nn.Sequential(
161
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
162
+ nn.Sigmoid())
163
+
164
+ def forward(self, x):
165
+ b, c, _, _ = x.size()
166
+ y = self.avg_pool(x).view(b, c)
167
+ y = self.fc(y).view(b, c, 1, 1)
168
+ return x * y
169
+
170
+
171
+ @ARCH_REGISTRY.register()
172
+ class ResNetArcFace(nn.Module):
173
+ """ArcFace with ResNet architectures.
174
+
175
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
176
+
177
+ Args:
178
+ block (str): Block used in the ArcFace architecture.
179
+ layers (tuple(int)): Block numbers in each layer.
180
+ use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
181
+ """
182
+
183
+ def __init__(self, block, layers, use_se=True):
184
+ if block == 'IRBlock':
185
+ block = IRBlock
186
+ self.inplanes = 64
187
+ self.use_se = use_se
188
+ super(ResNetArcFace, self).__init__()
189
+
190
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
191
+ self.bn1 = nn.BatchNorm2d(64)
192
+ self.prelu = nn.PReLU()
193
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
194
+ self.layer1 = self._make_layer(block, 64, layers[0])
195
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
196
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
197
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
198
+ self.bn4 = nn.BatchNorm2d(512)
199
+ self.dropout = nn.Dropout()
200
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
201
+ self.bn5 = nn.BatchNorm1d(512)
202
+
203
+ # initialization
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.xavier_normal_(m.weight)
207
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
208
+ nn.init.constant_(m.weight, 1)
209
+ nn.init.constant_(m.bias, 0)
210
+ elif isinstance(m, nn.Linear):
211
+ nn.init.xavier_normal_(m.weight)
212
+ nn.init.constant_(m.bias, 0)
213
+
214
+ def _make_layer(self, block, planes, num_blocks, stride=1):
215
+ downsample = None
216
+ if stride != 1 or self.inplanes != planes * block.expansion:
217
+ downsample = nn.Sequential(
218
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
219
+ nn.BatchNorm2d(planes * block.expansion),
220
+ )
221
+ layers = []
222
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
223
+ self.inplanes = planes
224
+ for _ in range(1, num_blocks):
225
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
226
+
227
+ return nn.Sequential(*layers)
228
+
229
+ def forward(self, x):
230
+ x = self.conv1(x)
231
+ x = self.bn1(x)
232
+ x = self.prelu(x)
233
+ x = self.maxpool(x)
234
+
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.bn4(x)
240
+ x = self.dropout(x)
241
+ x = x.view(x.size(0), -1)
242
+ x = self.fc5(x)
243
+ x = self.bn5(x)
244
+
245
+ return x
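As a usage sketch (not part of this commit), the identity network above is typically built with IR blocks and fed aligned grayscale crops. The `layers=[2, 2, 2, 2]` depth and the 128x128 input size are assumptions for illustration; the single-channel `conv1` and the `512 * 8 * 8` input to `fc5` imply a 1x128x128 input (128 -> 64 -> 32 -> 16 -> 8 after the max-pool and the three strided layers).

```python
# Sketch: instantiate ResNetArcFace and extract a 512-d identity embedding
# from a dummy batch of aligned grayscale faces (illustrative shapes).
import torch
from basicsr.archs.arcface_arch import ResNetArcFace

net = ResNetArcFace(block='IRBlock', layers=[2, 2, 2, 2], use_se=False)
net.eval()  # use running statistics in the BatchNorm layers

with torch.no_grad():
    faces = torch.randn(2, 1, 128, 128)   # (batch, gray channel, H, W)
    embeddings = net(faces)               # -> shape (2, 512)
```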
CodeFormer/basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,318 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from basicsr.utils import get_root_logger
15
+
16
+
17
+ @torch.no_grad()
18
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
+ """Initialize network weights.
20
+
21
+ Args:
22
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
+ scale (float): Scale initialized weights, especially for residual
24
+ blocks. Default: 1.
25
+ bias_fill (float): The value to fill bias. Default: 0
26
+ kwargs (dict): Other arguments for initialization function.
27
+ """
28
+ if not isinstance(module_list, list):
29
+ module_list = [module_list]
30
+ for module in module_list:
31
+ for m in module.modules():
32
+ if isinstance(m, nn.Conv2d):
33
+ init.kaiming_normal_(m.weight, **kwargs)
34
+ m.weight.data *= scale
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+ elif isinstance(m, nn.Linear):
38
+ init.kaiming_normal_(m.weight, **kwargs)
39
+ m.weight.data *= scale
40
+ if m.bias is not None:
41
+ m.bias.data.fill_(bias_fill)
42
+ elif isinstance(m, _BatchNorm):
43
+ init.constant_(m.weight, 1)
44
+ if m.bias is not None:
45
+ m.bias.data.fill_(bias_fill)
46
+
47
+
48
+ def make_layer(basic_block, num_basic_block, **kwarg):
49
+ """Make layers by stacking the same blocks.
50
+
51
+ Args:
52
+ basic_block (nn.module): nn.module class for basic block.
53
+ num_basic_block (int): number of blocks.
54
+
55
+ Returns:
56
+ nn.Sequential: Stacked blocks in nn.Sequential.
57
+ """
58
+ layers = []
59
+ for _ in range(num_basic_block):
60
+ layers.append(basic_block(**kwarg))
61
+ return nn.Sequential(*layers)
62
+
63
+
64
+ class ResidualBlockNoBN(nn.Module):
65
+ """Residual block without BN.
66
+
67
+ It has a style of:
68
+ ---Conv-ReLU-Conv-+-
69
+ |________________|
70
+
71
+ Args:
72
+ num_feat (int): Channel number of intermediate features.
73
+ Default: 64.
74
+ res_scale (float): Residual scale. Default: 1.
75
+ pytorch_init (bool): If set to True, use pytorch default init,
76
+ otherwise, use default_init_weights. Default: False.
77
+ """
78
+
79
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
80
+ super(ResidualBlockNoBN, self).__init__()
81
+ self.res_scale = res_scale
82
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
83
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
84
+ self.relu = nn.ReLU(inplace=True)
85
+
86
+ if not pytorch_init:
87
+ default_init_weights([self.conv1, self.conv2], 0.1)
88
+
89
+ def forward(self, x):
90
+ identity = x
91
+ out = self.conv2(self.relu(self.conv1(x)))
92
+ return identity + out * self.res_scale
93
+
94
+
95
+ class Upsample(nn.Sequential):
96
+ """Upsample module.
97
+
98
+ Args:
99
+ scale (int): Scale factor. Supported scales: 2^n and 3.
100
+ num_feat (int): Channel number of intermediate features.
101
+ """
102
+
103
+ def __init__(self, scale, num_feat):
104
+ m = []
105
+ if (scale & (scale - 1)) == 0: # scale = 2^n
106
+ for _ in range(int(math.log(scale, 2))):
107
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
108
+ m.append(nn.PixelShuffle(2))
109
+ elif scale == 3:
110
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
111
+ m.append(nn.PixelShuffle(3))
112
+ else:
113
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
114
+ super(Upsample, self).__init__(*m)
115
+
116
+
117
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
118
+ """Warp an image or feature map with optical flow.
119
+
120
+ Args:
121
+ x (Tensor): Tensor with size (n, c, h, w).
122
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
123
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
124
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
125
+ Default: 'zeros'.
126
+ align_corners (bool): Before pytorch 1.3, the default value is
127
+ align_corners=True. After pytorch 1.3, the default value is
128
+ align_corners=False. Here, we use the True as default.
129
+
130
+ Returns:
131
+ Tensor: Warped image or feature map.
132
+ """
133
+ assert x.size()[-2:] == flow.size()[1:3]
134
+ _, _, h, w = x.size()
135
+ # create mesh grid
136
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
137
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138
+ grid.requires_grad = False
139
+
140
+ vgrid = grid + flow
141
+ # scale grid to [-1,1]
142
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
146
+
147
+ # TODO, what if align_corners=False
148
+ return output
149
+
150
+
151
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
152
+ """Resize a flow according to ratio or shape.
153
+
154
+ Args:
155
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
156
+ size_type (str): 'ratio' or 'shape'.
157
+ sizes (list[int | float]): the ratio for resizing or the final output
158
+ shape.
159
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
160
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
161
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
162
+ ratio > 1.0).
163
+ 2) The order of output_size should be [out_h, out_w].
164
+ interp_mode (str): The mode of interpolation for resizing.
165
+ Default: 'bilinear'.
166
+ align_corners (bool): Whether align corners. Default: False.
167
+
168
+ Returns:
169
+ Tensor: Resized flow.
170
+ """
171
+ _, _, flow_h, flow_w = flow.size()
172
+ if size_type == 'ratio':
173
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
174
+ elif size_type == 'shape':
175
+ output_h, output_w = sizes[0], sizes[1]
176
+ else:
177
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
178
+
179
+ input_flow = flow.clone()
180
+ ratio_h = output_h / flow_h
181
+ ratio_w = output_w / flow_w
182
+ input_flow[:, 0, :, :] *= ratio_w
183
+ input_flow[:, 1, :, :] *= ratio_h
184
+ resized_flow = F.interpolate(
185
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
186
+ return resized_flow
187
+
188
+
189
+ # TODO: may write a cpp file
190
+ def pixel_unshuffle(x, scale):
191
+ """ Pixel unshuffle.
192
+
193
+ Args:
194
+ x (Tensor): Input feature with shape (b, c, hh, hw).
195
+ scale (int): Downsample ratio.
196
+
197
+ Returns:
198
+ Tensor: the pixel unshuffled feature.
199
+ """
200
+ b, c, hh, hw = x.size()
201
+ out_channel = c * (scale**2)
202
+ assert hh % scale == 0 and hw % scale == 0
203
+ h = hh // scale
204
+ w = hw // scale
205
+ x_view = x.view(b, c, h, scale, w, scale)
206
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
207
+
208
+
209
+ class DCNv2Pack(ModulatedDeformConvPack):
210
+ """Modulated deformable conv for deformable alignment.
211
+
212
+ Different from the official DCNv2Pack, which generates offsets and masks
213
+ from the preceding features, this DCNv2Pack takes another different
214
+ features to generate offsets and masks.
215
+
216
+ Ref:
217
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
218
+ """
219
+
220
+ def forward(self, x, feat):
221
+ out = self.conv_offset(feat)
222
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
223
+ offset = torch.cat((o1, o2), dim=1)
224
+ mask = torch.sigmoid(mask)
225
+
226
+ offset_absmean = torch.mean(torch.abs(offset))
227
+ if offset_absmean > 50:
228
+ logger = get_root_logger()
229
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
230
+
231
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
232
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
233
+ self.dilation, mask)
234
+ else:
235
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
236
+ self.dilation, self.groups, self.deformable_groups)
237
+
238
+
239
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
240
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
241
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
242
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
243
+ def norm_cdf(x):
244
+ # Computes standard normal cumulative distribution function
245
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
246
+
247
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
248
+ warnings.warn(
249
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
250
+ 'The distribution of values may be incorrect.',
251
+ stacklevel=2)
252
+
253
+ with torch.no_grad():
254
+ # Values are generated by using a truncated uniform distribution and
255
+ # then using the inverse CDF for the normal distribution.
256
+ # Get upper and lower cdf values
257
+ low = norm_cdf((a - mean) / std)
258
+ up = norm_cdf((b - mean) / std)
259
+
260
+ # Uniformly fill tensor with values from [low, up], then translate to
261
+ # [2l-1, 2u-1].
262
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
263
+
264
+ # Use inverse cdf transform for normal distribution to get truncated
265
+ # standard normal
266
+ tensor.erfinv_()
267
+
268
+ # Transform to proper mean, std
269
+ tensor.mul_(std * math.sqrt(2.))
270
+ tensor.add_(mean)
271
+
272
+ # Clamp to ensure it's in the proper range
273
+ tensor.clamp_(min=a, max=b)
274
+ return tensor
275
+
276
+
277
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
278
+ r"""Fills the input Tensor with values drawn from a truncated
279
+ normal distribution.
280
+
281
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
282
+
283
+ The values are effectively drawn from the
284
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
285
+ with values outside :math:`[a, b]` redrawn until they are within
286
+ the bounds. The method used for generating the random values works
287
+ best when :math:`a \leq \text{mean} \leq b`.
288
+
289
+ Args:
290
+ tensor: an n-dimensional `torch.Tensor`
291
+ mean: the mean of the normal distribution
292
+ std: the standard deviation of the normal distribution
293
+ a: the minimum cutoff value
294
+ b: the maximum cutoff value
295
+
296
+ Examples:
297
+ >>> w = torch.empty(3, 5)
298
+ >>> nn.init.trunc_normal_(w)
299
+ """
300
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
+
302
+
303
+ # From PyTorch
304
+ def _ntuple(n):
305
+
306
+ def parse(x):
307
+ if isinstance(x, collections.abc.Iterable):
308
+ return x
309
+ return tuple(repeat(x, n))
310
+
311
+ return parse
312
+
313
+
314
+ to_1tuple = _ntuple(1)
315
+ to_2tuple = _ntuple(2)
316
+ to_3tuple = _ntuple(3)
317
+ to_4tuple = _ntuple(4)
318
+ to_ntuple = _ntuple
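Two of the helpers above are used throughout the other architectures; a short sketch of their contracts follows (illustrative shapes, and it assumes the package's optional compiled DCN ops import cleanly even when they have not been built):

```python
# Sketch: make_layer stacks identical blocks into an nn.Sequential, and
# pixel_unshuffle trades spatial resolution for channels (H/s, W/s, C*s*s).
import torch
from basicsr.archs.arch_util import ResidualBlockNoBN, make_layer, pixel_unshuffle

trunk = make_layer(ResidualBlockNoBN, num_basic_block=4, num_feat=64)
feat = torch.randn(1, 64, 32, 32)
out = trunk(feat)                    # residual blocks preserve shape: (1, 64, 32, 32)

x = torch.randn(1, 3, 64, 64)
y = pixel_unshuffle(x, scale=2)      # -> (1, 12, 32, 32)
```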
CodeFormer/basicsr/archs/codeformer_arch.py ADDED
@@ -0,0 +1,276 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ import torch.nn.functional as F
6
+ from typing import Optional, List
7
+
8
+ from basicsr.archs.vqgan_arch import *
9
+ from basicsr.utils import get_root_logger
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
+ def calc_mean_std(feat, eps=1e-5):
13
+ """Calculate mean and std for adaptive_instance_normalization.
14
+
15
+ Args:
16
+ feat (Tensor): 4D tensor.
17
+ eps (float): A small value added to the variance to avoid
18
+ divide-by-zero. Default: 1e-5.
19
+ """
20
+ size = feat.size()
21
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
22
+ b, c = size[:2]
23
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
24
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
25
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
26
+ return feat_mean, feat_std
27
+
28
+
29
+ def adaptive_instance_normalization(content_feat, style_feat):
30
+ """Adaptive instance normalization.
31
+
32
+ Adjust the reference features to have a similar color and illumination
33
+ to those of the degraded features.
34
+
35
+ Args:
36
+ content_feat (Tensor): The reference feature.
37
+ style_feat (Tensor): The degraded features.
38
+ """
39
+ size = content_feat.size()
40
+ style_mean, style_std = calc_mean_std(style_feat)
41
+ content_mean, content_std = calc_mean_std(content_feat)
42
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
43
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
44
+
45
+
46
+ class PositionEmbeddingSine(nn.Module):
47
+ """
48
+ This is a more standard version of the position embedding, very similar to the one
49
+ used by the Attention is all you need paper, generalized to work on images.
50
+ """
51
+
52
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
53
+ super().__init__()
54
+ self.num_pos_feats = num_pos_feats
55
+ self.temperature = temperature
56
+ self.normalize = normalize
57
+ if scale is not None and normalize is False:
58
+ raise ValueError("normalize should be True if scale is passed")
59
+ if scale is None:
60
+ scale = 2 * math.pi
61
+ self.scale = scale
62
+
63
+ def forward(self, x, mask=None):
64
+ if mask is None:
65
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
66
+ not_mask = ~mask
67
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
68
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
69
+ if self.normalize:
70
+ eps = 1e-6
71
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
72
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
73
+
74
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
75
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
76
+
77
+ pos_x = x_embed[:, :, :, None] / dim_t
78
+ pos_y = y_embed[:, :, :, None] / dim_t
79
+ pos_x = torch.stack(
80
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
81
+ ).flatten(3)
82
+ pos_y = torch.stack(
83
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
84
+ ).flatten(3)
85
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
86
+ return pos
87
+
88
+ def _get_activation_fn(activation):
89
+ """Return an activation function given a string"""
90
+ if activation == "relu":
91
+ return F.relu
92
+ if activation == "gelu":
93
+ return F.gelu
94
+ if activation == "glu":
95
+ return F.glu
96
+ raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
97
+
98
+
99
+ class TransformerSALayer(nn.Module):
100
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
101
+ super().__init__()
102
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
103
+ # Implementation of Feedforward model - MLP
104
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
105
+ self.dropout = nn.Dropout(dropout)
106
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
107
+
108
+ self.norm1 = nn.LayerNorm(embed_dim)
109
+ self.norm2 = nn.LayerNorm(embed_dim)
110
+ self.dropout1 = nn.Dropout(dropout)
111
+ self.dropout2 = nn.Dropout(dropout)
112
+
113
+ self.activation = _get_activation_fn(activation)
114
+
115
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
116
+ return tensor if pos is None else tensor + pos
117
+
118
+ def forward(self, tgt,
119
+ tgt_mask: Optional[Tensor] = None,
120
+ tgt_key_padding_mask: Optional[Tensor] = None,
121
+ query_pos: Optional[Tensor] = None):
122
+
123
+ # self attention
124
+ tgt2 = self.norm1(tgt)
125
+ q = k = self.with_pos_embed(tgt2, query_pos)
126
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
127
+ key_padding_mask=tgt_key_padding_mask)[0]
128
+ tgt = tgt + self.dropout1(tgt2)
129
+
130
+ # ffn
131
+ tgt2 = self.norm2(tgt)
132
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
133
+ tgt = tgt + self.dropout2(tgt2)
134
+ return tgt
135
+
136
+ class Fuse_sft_block(nn.Module):
137
+ def __init__(self, in_ch, out_ch):
138
+ super().__init__()
139
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
140
+
141
+ self.scale = nn.Sequential(
142
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
143
+ nn.LeakyReLU(0.2, True),
144
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
145
+
146
+ self.shift = nn.Sequential(
147
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
148
+ nn.LeakyReLU(0.2, True),
149
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
150
+
151
+ def forward(self, enc_feat, dec_feat, w=1):
152
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
153
+ scale = self.scale(enc_feat)
154
+ shift = self.shift(enc_feat)
155
+ residual = w * (dec_feat * scale + shift)
156
+ out = dec_feat + residual
157
+ return out
158
+
159
+
160
+ @ARCH_REGISTRY.register()
161
+ class CodeFormer(VQAutoEncoder):
162
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
163
+ codebook_size=1024, latent_size=256,
164
+ connect_list=['32', '64', '128', '256'],
165
+ fix_modules=['quantize','generator']):
166
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
167
+
168
+ if fix_modules is not None:
169
+ for module in fix_modules:
170
+ for param in getattr(self, module).parameters():
171
+ param.requires_grad = False
172
+
173
+ self.connect_list = connect_list
174
+ self.n_layers = n_layers
175
+ self.dim_embd = dim_embd
176
+ self.dim_mlp = dim_embd*2
177
+
178
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
179
+ self.feat_emb = nn.Linear(256, self.dim_embd)
180
+
181
+ # transformer
182
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
183
+ for _ in range(self.n_layers)])
184
+
185
+ # logits_predict head
186
+ self.idx_pred_layer = nn.Sequential(
187
+ nn.LayerNorm(dim_embd),
188
+ nn.Linear(dim_embd, codebook_size, bias=False))
189
+
190
+ self.channels = {
191
+ '16': 512,
192
+ '32': 256,
193
+ '64': 256,
194
+ '128': 128,
195
+ '256': 128,
196
+ '512': 64,
197
+ }
198
+
199
+ # after second residual block for > 16, before attn layer for ==16
200
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
201
+ # after first residual block for > 16, before attn layer for ==16
202
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
203
+
204
+ # fuse_convs_dict
205
+ self.fuse_convs_dict = nn.ModuleDict()
206
+ for f_size in self.connect_list:
207
+ in_ch = self.channels[f_size]
208
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
209
+
210
+ def _init_weights(self, module):
211
+ if isinstance(module, (nn.Linear, nn.Embedding)):
212
+ module.weight.data.normal_(mean=0.0, std=0.02)
213
+ if isinstance(module, nn.Linear) and module.bias is not None:
214
+ module.bias.data.zero_()
215
+ elif isinstance(module, nn.LayerNorm):
216
+ module.bias.data.zero_()
217
+ module.weight.data.fill_(1.0)
218
+
219
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
220
+ # ################### Encoder #####################
221
+ enc_feat_dict = {}
222
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
223
+ for i, block in enumerate(self.encoder.blocks):
224
+ x = block(x)
225
+ if i in out_list:
226
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
227
+
228
+ lq_feat = x
229
+ # ################# Transformer ###################
230
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
231
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
232
+ # BCHW -> BC(HW) -> (HW)BC
233
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
234
+ query_emb = feat_emb
235
+ # Transformer encoder
236
+ for layer in self.ft_layers:
237
+ query_emb = layer(query_emb, query_pos=pos_emb)
238
+
239
+ # output logits
240
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
241
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
242
+
243
+ if code_only: # for training stage II
244
+ # logits doesn't need softmax before cross_entropy loss
245
+ return logits, lq_feat
246
+
247
+ # ################# Quantization ###################
248
+ # if self.training:
249
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
250
+ # # b(hw)c -> bc(hw) -> bchw
251
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
252
+ # ------------
253
+ soft_one_hot = F.softmax(logits, dim=2)
254
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
255
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
256
+ # preserve gradients
257
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
258
+
259
+ if detach_16:
260
+ quant_feat = quant_feat.detach() # for training stage III
261
+ if adain:
262
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
263
+
264
+ # ################## Generator ####################
265
+ x = quant_feat
266
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
267
+
268
+ for i, block in enumerate(self.generator.blocks):
269
+ x = block(x)
270
+ if i in fuse_list: # fuse after i-th block
271
+ f_size = str(x.shape[-1])
272
+ if w>0:
273
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
274
+ out = x
275
+ # logits doesn't need softmax before cross_entropy loss
276
+ return out, logits, lq_feat
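To make the forward interface concrete, here is a minimal sketch of running the generator on one cropped and aligned face. It is illustrative only: random weights and a 512x512 RGB input are assumed; the released checkpoint is downloaded separately into `weights/CodeFormer` as described in the README.

```python
# Sketch: CodeFormer forward pass. w trades fidelity against quality, matching
# the --w flag in the README: smaller w leans on the codebook prediction
# (quality), larger w mixes in more encoder features (fidelity).
import torch
from basicsr.archs.codeformer_arch import CodeFormer

net = CodeFormer(dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, connect_list=['32', '64', '128', '256'])
net.eval()

with torch.no_grad():
    face = torch.randn(1, 3, 512, 512)              # normalized aligned face
    output, logits, lq_feat = net(face, w=0.5, adain=True)
# output: restored image, shape (1, 3, 512, 512)
# logits: per-position predictions over the 1024-entry codebook, shape (1, 16*16, 1024)
```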
CodeFormer/basicsr/archs/rrdbnet_arch.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
+
8
+
9
+ class ResidualDenseBlock(nn.Module):
10
+ """Residual Dense Block.
11
+
12
+ Used in RRDB block in ESRGAN.
13
+
14
+ Args:
15
+ num_feat (int): Channel number of intermediate features.
16
+ num_grow_ch (int): Channels for each growth.
17
+ """
18
+
19
+ def __init__(self, num_feat=64, num_grow_ch=32):
20
+ super(ResidualDenseBlock, self).__init__()
21
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
22
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
25
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
26
+
27
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
28
+
29
+ # initialization
30
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
31
+
32
+ def forward(self, x):
33
+ x1 = self.lrelu(self.conv1(x))
34
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
35
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
36
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
37
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
38
+ # Empirically, we use 0.2 to scale the residual for better performance
39
+ return x5 * 0.2 + x
40
+
41
+
42
+ class RRDB(nn.Module):
43
+ """Residual in Residual Dense Block.
44
+
45
+ Used in RRDB-Net in ESRGAN.
46
+
47
+ Args:
48
+ num_feat (int): Channel number of intermediate features.
49
+ num_grow_ch (int): Channels for each growth.
50
+ """
51
+
52
+ def __init__(self, num_feat, num_grow_ch=32):
53
+ super(RRDB, self).__init__()
54
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
57
+
58
+ def forward(self, x):
59
+ out = self.rdb1(x)
60
+ out = self.rdb2(out)
61
+ out = self.rdb3(out)
62
+ # Empirically, we use 0.2 to scale the residual for better performance
63
+ return out * 0.2 + x
64
+
65
+
66
+ @ARCH_REGISTRY.register()
67
+ class RRDBNet(nn.Module):
68
+ """Networks consisting of Residual in Residual Dense Block, which is used
69
+ in ESRGAN.
70
+
71
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
72
+
73
+ We extend ESRGAN for scale x2 and scale x1.
74
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
75
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
76
+ and enlarge the channel size) before feeding inputs into the main ESRGAN architecture.
77
+
78
+ Args:
79
+ num_in_ch (int): Channel number of inputs.
80
+ num_out_ch (int): Channel number of outputs.
81
+ num_feat (int): Channel number of intermediate features.
82
+ Default: 64
83
+ num_block (int): Block number in the trunk network. Defaults: 23
84
+ num_grow_ch (int): Channels for each growth. Default: 32.
85
+ """
86
+
87
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
88
+ super(RRDBNet, self).__init__()
89
+ self.scale = scale
90
+ if scale == 2:
91
+ num_in_ch = num_in_ch * 4
92
+ elif scale == 1:
93
+ num_in_ch = num_in_ch * 16
94
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
95
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
96
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ # upsample
98
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
+
103
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
+
105
+ def forward(self, x):
106
+ if self.scale == 2:
107
+ feat = pixel_unshuffle(x, scale=2)
108
+ elif self.scale == 1:
109
+ feat = pixel_unshuffle(x, scale=4)
110
+ else:
111
+ feat = x
112
+ feat = self.conv_first(feat)
113
+ body_feat = self.conv_body(self.body(feat))
114
+ feat = feat + body_feat
115
+ # upsample
116
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
119
+ return out
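A short usage sketch (not part of this commit) of the RRDB backbone above in its x2 configuration, the kind of network the Real-ESRGAN upsamplers mentioned in the README are built on; `num_feat`/`num_block` below are simply the constructor defaults:

```python
# Sketch: 2x super-resolution with RRDBNet. For scale=2 the forward pass first
# pixel-unshuffles the input (channels x4, spatial /2), then upsamples by 4,
# so the output ends up at 2x the input resolution.
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2,
              num_feat=64, num_block=23, num_grow_ch=32)
net.eval()

with torch.no_grad():
    lr = torch.randn(1, 3, 64, 64)
    sr = net(lr)          # -> (1, 3, 128, 128)
```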
CodeFormer/basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+ VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10
+ NAMES = {
11
+ 'vgg11': [
12
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14
+ 'pool5'
15
+ ],
16
+ 'vgg13': [
17
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20
+ ],
21
+ 'vgg16': [
22
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25
+ 'pool5'
26
+ ],
27
+ 'vgg19': [
28
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32
+ ]
33
+ }
34
+
35
+
36
+ def insert_bn(names):
37
+ """Insert bn layer after each conv.
38
+
39
+ Args:
40
+ names (list): The list of layer names.
41
+
42
+ Returns:
43
+ list: The list of layer names with bn layers.
44
+ """
45
+ names_bn = []
46
+ for name in names:
47
+ names_bn.append(name)
48
+ if 'conv' in name:
49
+ position = name.replace('conv', '')
50
+ names_bn.append('bn' + position)
51
+ return names_bn
52
+
53
+
54
+ @ARCH_REGISTRY.register()
55
+ class VGGFeatureExtractor(nn.Module):
56
+ """VGG network for feature extraction.
57
+
58
+ In this implementation, we allow users to choose whether to use normalization
59
+ in the input feature and the type of vgg network. Note that the pretrained
60
+ path must fit the vgg type.
61
+
62
+ Args:
63
+ layer_name_list (list[str]): Forward function returns the corresponding
64
+ features according to the layer_name_list.
65
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67
+ use_input_norm (bool): If True, normalize the input image. Importantly,
68
+ the input feature must be in the range [0, 1]. Default: True.
69
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
70
+ Default: False.
71
+ requires_grad (bool): If true, the parameters of VGG network will be
72
+ optimized. Default: False.
73
+ remove_pooling (bool): If true, the max pooling operations in VGG net
74
+ will be removed. Default: False.
75
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
76
+ """
77
+
78
+ def __init__(self,
79
+ layer_name_list,
80
+ vgg_type='vgg19',
81
+ use_input_norm=True,
82
+ range_norm=False,
83
+ requires_grad=False,
84
+ remove_pooling=False,
85
+ pooling_stride=2):
86
+ super(VGGFeatureExtractor, self).__init__()
87
+
88
+ self.layer_name_list = layer_name_list
89
+ self.use_input_norm = use_input_norm
90
+ self.range_norm = range_norm
91
+
92
+ self.names = NAMES[vgg_type.replace('_bn', '')]
93
+ if 'bn' in vgg_type:
94
+ self.names = insert_bn(self.names)
95
+
96
+ # only borrow layers that will be used to avoid unused params
97
+ max_idx = 0
98
+ for v in layer_name_list:
99
+ idx = self.names.index(v)
100
+ if idx > max_idx:
101
+ max_idx = idx
102
+
103
+ if os.path.exists(VGG_PRETRAIN_PATH):
104
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106
+ vgg_net.load_state_dict(state_dict)
107
+ else:
108
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109
+
110
+ features = vgg_net.features[:max_idx + 1]
111
+
112
+ modified_net = OrderedDict()
113
+ for k, v in zip(self.names, features):
114
+ if 'pool' in k:
115
+ # if remove_pooling is true, pooling operation will be removed
116
+ if remove_pooling:
117
+ continue
118
+ else:
119
+ # in some cases, we may want to change the default stride
120
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121
+ else:
122
+ modified_net[k] = v
123
+
124
+ self.vgg_net = nn.Sequential(modified_net)
125
+
126
+ if not requires_grad:
127
+ self.vgg_net.eval()
128
+ for param in self.parameters():
129
+ param.requires_grad = False
130
+ else:
131
+ self.vgg_net.train()
132
+ for param in self.parameters():
133
+ param.requires_grad = True
134
+
135
+ if self.use_input_norm:
136
+ # the mean is for image with range [0, 1]
137
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138
+ # the std is for image with range [0, 1]
139
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140
+
141
+ def forward(self, x):
142
+ """Forward function.
143
+
144
+ Args:
145
+ x (Tensor): Input tensor with shape (n, c, h, w).
146
+
147
+ Returns:
148
+ Tensor: Forward results.
149
+ """
150
+ if self.range_norm:
151
+ x = (x + 1) / 2
152
+ if self.use_input_norm:
153
+ x = (x - self.mean) / self.std
154
+ output = {}
155
+
156
+ for key, layer in self.vgg_net._modules.items():
157
+ x = layer(x)
158
+ if key in self.layer_name_list:
159
+ output[key] = x.clone()
160
+
161
+ return output
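
For reference, a minimal usage sketch of the VGGFeatureExtractor added above (illustrative, not part of the commit; it assumes the torchvision VGG19 weights can be found at VGG_PRETRAIN_PATH or downloaded):

    import torch
    from basicsr.archs.vgg_arch import VGGFeatureExtractor

    extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'], vgg_type='vgg19')
    x = torch.rand(1, 3, 224, 224)                      # input must be in [0, 1] when use_input_norm=True
    features = extractor(x)                             # dict mapping layer names to feature tensors
    print({k: v.shape for k, v in features.items()})
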
CodeFormer/basicsr/archs/vqgan_arch.py ADDED
@@ -0,0 +1,435 @@
1
+ '''
2
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
+
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
+ def normalize(in_channels):
15
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
16
+
17
+
18
+ @torch.jit.script
19
+ def swish(x):
20
+ return x*torch.sigmoid(x)
21
+
22
+
23
+ # Define VQVAE classes
24
+ class VectorQuantizer(nn.Module):
25
+ def __init__(self, codebook_size, emb_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.codebook_size = codebook_size # number of embeddings
28
+ self.emb_dim = emb_dim # dimension of embedding
29
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
30
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
31
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
32
+
33
+ def forward(self, z):
34
+ # reshape z -> (batch, height, width, channel) and flatten
35
+ z = z.permute(0, 2, 3, 1).contiguous()
36
+ z_flattened = z.view(-1, self.emb_dim)
37
+
38
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
39
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
40
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
41
+
42
+ mean_distance = torch.mean(d)
43
+ # find closest encodings
44
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
45
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
46
+ # [0-1], higher score, higher confidence
47
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
48
+
49
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
50
+ min_encodings.scatter_(1, min_encoding_indices, 1)
51
+
52
+ # get quantized latent vectors
53
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
54
+ # compute loss for embedding
55
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
56
+ # preserve gradients
57
+ z_q = z + (z_q - z).detach()
58
+
59
+ # perplexity
60
+ e_mean = torch.mean(min_encodings, dim=0)
61
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
62
+ # reshape back to match original input shape
63
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
64
+
65
+ return z_q, loss, {
66
+ "perplexity": perplexity,
67
+ "min_encodings": min_encodings,
68
+ "min_encoding_indices": min_encoding_indices,
69
+ "min_encoding_scores": min_encoding_scores,
70
+ "mean_distance": mean_distance
71
+ }
72
+
73
+ def get_codebook_feat(self, indices, shape):
74
+ # input indices: batch*token_num -> (batch*token_num)*1
75
+ # shape: batch, height, width, channel
76
+ indices = indices.view(-1,1)
77
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
78
+ min_encodings.scatter_(1, indices, 1)
79
+ # get quantized latent vectors
80
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
81
+
82
+ if shape is not None: # reshape back to match original input shape
83
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
84
+
85
+ return z_q
86
+
87
+
88
+ class GumbelQuantizer(nn.Module):
89
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
90
+ super().__init__()
91
+ self.codebook_size = codebook_size # number of embeddings
92
+ self.emb_dim = emb_dim # dimension of embedding
93
+ self.straight_through = straight_through
94
+ self.temperature = temp_init
95
+ self.kl_weight = kl_weight
96
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
97
+ self.embed = nn.Embedding(codebook_size, emb_dim)
98
+
99
+ def forward(self, z):
100
+ hard = self.straight_through if self.training else True
101
+
102
+ logits = self.proj(z)
103
+
104
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
105
+
106
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
107
+
108
+ # + kl divergence to the prior loss
109
+ qy = F.softmax(logits, dim=1)
110
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
111
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
112
+
113
+ return z_q, diff, {
114
+ "min_encoding_indices": min_encoding_indices
115
+ }
116
+
117
+
118
+ class Downsample(nn.Module):
119
+ def __init__(self, in_channels):
120
+ super().__init__()
121
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
122
+
123
+ def forward(self, x):
124
+ pad = (0, 1, 0, 1)
125
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
126
+ x = self.conv(x)
127
+ return x
128
+
129
+
130
+ class Upsample(nn.Module):
131
+ def __init__(self, in_channels):
132
+ super().__init__()
133
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
134
+
135
+ def forward(self, x):
136
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
137
+ x = self.conv(x)
138
+
139
+ return x
140
+
141
+
142
+ class ResBlock(nn.Module):
143
+ def __init__(self, in_channels, out_channels=None):
144
+ super(ResBlock, self).__init__()
145
+ self.in_channels = in_channels
146
+ self.out_channels = in_channels if out_channels is None else out_channels
147
+ self.norm1 = normalize(in_channels)
148
+ self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(self.out_channels)
+ self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
151
+ if self.in_channels != self.out_channels:
152
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
153
+
154
+ def forward(self, x_in):
155
+ x = x_in
156
+ x = self.norm1(x)
157
+ x = swish(x)
158
+ x = self.conv1(x)
159
+ x = self.norm2(x)
160
+ x = swish(x)
161
+ x = self.conv2(x)
162
+ if self.in_channels != self.out_channels:
163
+ x_in = self.conv_out(x_in)
164
+
165
+ return x + x_in
166
+
167
+
168
+ class AttnBlock(nn.Module):
169
+ def __init__(self, in_channels):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+
173
+ self.norm = normalize(in_channels)
174
+ self.q = torch.nn.Conv2d(
175
+ in_channels,
176
+ in_channels,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0
180
+ )
181
+ self.k = torch.nn.Conv2d(
182
+ in_channels,
183
+ in_channels,
184
+ kernel_size=1,
185
+ stride=1,
186
+ padding=0
187
+ )
188
+ self.v = torch.nn.Conv2d(
189
+ in_channels,
190
+ in_channels,
191
+ kernel_size=1,
192
+ stride=1,
193
+ padding=0
194
+ )
195
+ self.proj_out = torch.nn.Conv2d(
196
+ in_channels,
197
+ in_channels,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0
201
+ )
202
+
203
+ def forward(self, x):
204
+ h_ = x
205
+ h_ = self.norm(h_)
206
+ q = self.q(h_)
207
+ k = self.k(h_)
208
+ v = self.v(h_)
209
+
210
+ # compute attention
211
+ b, c, h, w = q.shape
212
+ q = q.reshape(b, c, h*w)
213
+ q = q.permute(0, 2, 1)
214
+ k = k.reshape(b, c, h*w)
215
+ w_ = torch.bmm(q, k)
216
+ w_ = w_ * (int(c)**(-0.5))
217
+ w_ = F.softmax(w_, dim=2)
218
+
219
+ # attend to values
220
+ v = v.reshape(b, c, h*w)
221
+ w_ = w_.permute(0, 2, 1)
222
+ h_ = torch.bmm(v, w_)
223
+ h_ = h_.reshape(b, c, h, w)
224
+
225
+ h_ = self.proj_out(h_)
226
+
227
+ return x+h_
228
+
229
+
230
+ class Encoder(nn.Module):
231
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
232
+ super().__init__()
233
+ self.nf = nf
234
+ self.num_resolutions = len(ch_mult)
235
+ self.num_res_blocks = num_res_blocks
236
+ self.resolution = resolution
237
+ self.attn_resolutions = attn_resolutions
238
+
239
+ curr_res = self.resolution
240
+ in_ch_mult = (1,)+tuple(ch_mult)
241
+
242
+ blocks = []
243
+ # initial convolution
244
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
245
+
246
+ # residual and downsampling blocks, with attention on smaller res (16x16)
247
+ for i in range(self.num_resolutions):
248
+ block_in_ch = nf * in_ch_mult[i]
249
+ block_out_ch = nf * ch_mult[i]
250
+ for _ in range(self.num_res_blocks):
251
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
252
+ block_in_ch = block_out_ch
253
+ if curr_res in attn_resolutions:
254
+ blocks.append(AttnBlock(block_in_ch))
255
+
256
+ if i != self.num_resolutions - 1:
257
+ blocks.append(Downsample(block_in_ch))
258
+ curr_res = curr_res // 2
259
+
260
+ # non-local attention block
261
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
262
+ blocks.append(AttnBlock(block_in_ch))
263
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
264
+
265
+ # normalise and convert to latent size
266
+ blocks.append(normalize(block_in_ch))
267
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
268
+ self.blocks = nn.ModuleList(blocks)
269
+
270
+ def forward(self, x):
271
+ for block in self.blocks:
272
+ x = block(x)
273
+
274
+ return x
275
+
276
+
277
+ class Generator(nn.Module):
278
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
279
+ super().__init__()
280
+ self.nf = nf
281
+ self.ch_mult = ch_mult
282
+ self.num_resolutions = len(self.ch_mult)
283
+ self.num_res_blocks = res_blocks
284
+ self.resolution = img_size
285
+ self.attn_resolutions = attn_resolutions
286
+ self.in_channels = emb_dim
287
+ self.out_channels = 3
288
+ block_in_ch = self.nf * self.ch_mult[-1]
289
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
290
+
291
+ blocks = []
292
+ # initial conv
293
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
294
+
295
+ # non-local attention block
296
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
297
+ blocks.append(AttnBlock(block_in_ch))
298
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
299
+
300
+ for i in reversed(range(self.num_resolutions)):
301
+ block_out_ch = self.nf * self.ch_mult[i]
302
+
303
+ for _ in range(self.num_res_blocks):
304
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
305
+ block_in_ch = block_out_ch
306
+
307
+ if curr_res in self.attn_resolutions:
308
+ blocks.append(AttnBlock(block_in_ch))
309
+
310
+ if i != 0:
311
+ blocks.append(Upsample(block_in_ch))
312
+ curr_res = curr_res * 2
313
+
314
+ blocks.append(normalize(block_in_ch))
315
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
316
+
317
+ self.blocks = nn.ModuleList(blocks)
318
+
319
+
320
+ def forward(self, x):
321
+ for block in self.blocks:
322
+ x = block(x)
323
+
324
+ return x
325
+
326
+
327
+ @ARCH_REGISTRY.register()
328
+ class VQAutoEncoder(nn.Module):
329
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
330
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
331
+ super().__init__()
332
+ logger = get_root_logger()
333
+ self.in_channels = 3
334
+ self.nf = nf
335
+ self.n_blocks = res_blocks
336
+ self.codebook_size = codebook_size
337
+ self.embed_dim = emb_dim
338
+ self.ch_mult = ch_mult
339
+ self.resolution = img_size
340
+ self.attn_resolutions = attn_resolutions
341
+ self.quantizer_type = quantizer
342
+ self.encoder = Encoder(
343
+ self.in_channels,
344
+ self.nf,
345
+ self.embed_dim,
346
+ self.ch_mult,
347
+ self.n_blocks,
348
+ self.resolution,
349
+ self.attn_resolutions
350
+ )
351
+ if self.quantizer_type == "nearest":
352
+ self.beta = beta #0.25
353
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
354
+ elif self.quantizer_type == "gumbel":
355
+ self.gumbel_num_hiddens = emb_dim
356
+ self.straight_through = gumbel_straight_through
357
+ self.kl_weight = gumbel_kl_weight
358
+ self.quantize = GumbelQuantizer(
359
+ self.codebook_size,
360
+ self.embed_dim,
361
+ self.gumbel_num_hiddens,
362
+ self.straight_through,
363
+ self.kl_weight
364
+ )
365
+ self.generator = Generator(
366
+ self.nf,
367
+ self.embed_dim,
368
+ self.ch_mult,
369
+ self.n_blocks,
370
+ self.resolution,
371
+ self.attn_resolutions
372
+ )
373
+
374
+ if model_path is not None:
375
+ chkpt = torch.load(model_path, map_location='cpu')
376
+ if 'params_ema' in chkpt:
377
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
378
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
379
+ elif 'params' in chkpt:
380
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
381
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
382
+ else:
383
+ raise ValueError(f'Wrong params!')
384
+
385
+
386
+ def forward(self, x):
387
+ x = self.encoder(x)
388
+ quant, codebook_loss, quant_stats = self.quantize(x)
389
+ x = self.generator(quant)
390
+ return x, codebook_loss, quant_stats
391
+
392
+
393
+
394
+ # patch based discriminator
395
+ @ARCH_REGISTRY.register()
396
+ class VQGANDiscriminator(nn.Module):
397
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
398
+ super().__init__()
399
+
400
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
401
+ ndf_mult = 1
402
+ ndf_mult_prev = 1
403
+ for n in range(1, n_layers): # gradually increase the number of filters
404
+ ndf_mult_prev = ndf_mult
405
+ ndf_mult = min(2 ** n, 8)
406
+ layers += [
407
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
408
+ nn.BatchNorm2d(ndf * ndf_mult),
409
+ nn.LeakyReLU(0.2, True)
410
+ ]
411
+
412
+ ndf_mult_prev = ndf_mult
413
+ ndf_mult = min(2 ** n_layers, 8)
414
+
415
+ layers += [
416
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
417
+ nn.BatchNorm2d(ndf * ndf_mult),
418
+ nn.LeakyReLU(0.2, True)
419
+ ]
420
+
421
+ layers += [
422
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
423
+ self.main = nn.Sequential(*layers)
424
+
425
+ if model_path is not None:
426
+ chkpt = torch.load(model_path, map_location='cpu')
427
+ if 'params_d' in chkpt:
428
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
429
+ elif 'params' in chkpt:
430
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
431
+ else:
432
+ raise ValueError(f'Wrong params!')
433
+
434
+ def forward(self, x):
435
+ return self.main(x)
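
A minimal sketch of how the architectures in this file fit together (illustrative, not part of the commit; the image size and channel multipliers below are assumptions, the repo's training configs define the real values):

    import torch
    from basicsr.archs.vqgan_arch import VQAutoEncoder, VQGANDiscriminator

    vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8],
                          quantizer='nearest', codebook_size=1024, emb_dim=256)
    disc = VQGANDiscriminator(nc=3, ndf=64, n_layers=4)

    x = torch.rand(1, 3, 512, 512)
    recon, codebook_loss, quant_stats = vqgan(x)    # reconstruction, commitment loss, quantizer stats
    logits = disc(recon)                            # patch-wise real/fake prediction map
    print(recon.shape, codebook_loss.item(), logits.shape)
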
CodeFormer/basicsr/data/__init__.py ADDED
@@ -0,0 +1,100 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+
84
+ prefetch_mode = dataset_opt.get('prefetch_mode')
85
+ if prefetch_mode == 'cpu': # CPUPrefetcher
86
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
87
+ logger = get_root_logger()
88
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
89
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
90
+ else:
91
+ # prefetch_mode=None: Normal dataloader
92
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
93
+ return torch.utils.data.DataLoader(**dataloader_args)
94
+
95
+
96
+ def worker_init_fn(worker_id, num_workers, rank, seed):
97
+ # Set the worker seed to num_workers * rank + worker_id + seed
98
+ worker_seed = num_workers * rank + worker_id + seed
99
+ np.random.seed(worker_seed)
100
+ random.seed(worker_seed)
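
A sketch of how these two helpers are typically wired together (illustrative, not part of the commit; the dataset type and option keys are placeholders, real experiments define them in a YAML config):

    from basicsr.data import build_dataset, build_dataloader

    dataset_opt = {
        'name': 'demo',                      # placeholder values; see the repo's option files
        'type': 'SomeRegisteredDataset',     # must be a class decorated with @DATASET_REGISTRY.register()
        'phase': 'train',
        'batch_size_per_gpu': 4,
        'num_worker_per_gpu': 2,
        # ... dataset-specific keys such as dataroot_gt go here ...
    }

    train_set = build_dataset(dataset_opt)
    train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
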
CodeFormer/basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Support enlarging the dataset for iteration-based training, saving the
+ time of restarting the dataloader after each epoch.
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
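
A small self-contained sketch of the sampler's behaviour (illustrative, not part of the commit):

    import torch
    from torch.utils.data import TensorDataset
    from basicsr.data.data_sampler import EnlargedSampler

    dataset = TensorDataset(torch.arange(10))
    sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=2)   # each rank draws 10 indices per epoch
    sampler.set_epoch(0)                                                  # reseeds the deterministic shuffle
    print(list(iter(sampler)))                                            # indices wrap around via modulo
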
CodeFormer/basicsr/data/data_util.py ADDED
@@ -0,0 +1,305 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.data.transforms import mod_crop
8
+ from basicsr.utils import img2tensor, scandir
9
+
10
+
11
+ def read_img_seq(path, require_mod_crop=False, scale=1):
12
+ """Read a sequence of images from a given folder path.
13
+
14
+ Args:
15
+ path (list[str] | str): List of image paths or image folder path.
16
+ require_mod_crop (bool): Require mod crop for each image.
17
+ Default: False.
18
+ scale (int): Scale factor for mod_crop. Default: 1.
19
+
20
+ Returns:
21
+ Tensor: size (t, c, h, w), RGB, [0, 1].
22
+ """
23
+ if isinstance(path, list):
24
+ img_paths = path
25
+ else:
26
+ img_paths = sorted(list(scandir(path, full_path=True)))
27
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
28
+ if require_mod_crop:
29
+ imgs = [mod_crop(img, scale) for img in imgs]
30
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
31
+ imgs = torch.stack(imgs, dim=0)
32
+ return imgs
33
+
34
+
35
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
36
+ """Generate an index list for reading `num_frames` frames from a sequence
37
+ of images.
38
+
39
+ Args:
40
+ crt_idx (int): Current center index.
41
+ max_frame_num (int): Max number of the sequence of images (from 1).
42
+ num_frames (int): Reading num_frames frames.
43
+ padding (str): Padding mode, one of
44
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
45
+ Examples: current_idx = 0, num_frames = 5
46
+ The generated frame indices under different padding mode:
47
+ replicate: [0, 0, 0, 1, 2]
48
+ reflection: [2, 1, 0, 1, 2]
49
+ reflection_circle: [4, 3, 0, 1, 2]
50
+ circle: [3, 4, 0, 1, 2]
51
+
52
+ Returns:
53
+ list[int]: A list of indices.
54
+ """
55
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
56
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
57
+
58
+ max_frame_num = max_frame_num - 1 # start from 0
59
+ num_pad = num_frames // 2
60
+
61
+ indices = []
62
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
63
+ if i < 0:
64
+ if padding == 'replicate':
65
+ pad_idx = 0
66
+ elif padding == 'reflection':
67
+ pad_idx = -i
68
+ elif padding == 'reflection_circle':
69
+ pad_idx = crt_idx + num_pad - i
70
+ else:
71
+ pad_idx = num_frames + i
72
+ elif i > max_frame_num:
73
+ if padding == 'replicate':
74
+ pad_idx = max_frame_num
75
+ elif padding == 'reflection':
76
+ pad_idx = max_frame_num * 2 - i
77
+ elif padding == 'reflection_circle':
78
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
79
+ else:
80
+ pad_idx = i - num_frames
81
+ else:
82
+ pad_idx = i
83
+ indices.append(pad_idx)
84
+ return indices
85
+
86
+
87
+ def paired_paths_from_lmdb(folders, keys):
88
+ """Generate paired paths from lmdb files.
89
+
90
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
91
+
92
+ lq.lmdb
93
+ ├── data.mdb
94
+ ├── lock.mdb
95
+ ├── meta_info.txt
96
+
97
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
98
+ https://lmdb.readthedocs.io/en/release/ for more details.
99
+
100
+ The meta_info.txt is a specified txt file to record the meta information
101
+ of our datasets. It will be automatically created when preparing
102
+ datasets by our provided dataset tools.
103
+ Each line in the txt file records
104
+ 1)image name (with extension),
105
+ 2)image shape,
106
+ 3)compression level, separated by a white space.
107
+ Example: `baboon.png (120,125,3) 1`
108
+
109
+ We use the image name without extension as the lmdb key.
110
+ Note that we use the same key for the corresponding lq and gt images.
111
+
112
+ Args:
113
+ folders (list[str]): A list of folder path. The order of list should
114
+ be [input_folder, gt_folder].
115
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
117
+ Note that this key is different from lmdb keys.
118
+
119
+ Returns:
120
+ list[str]: Returned path list.
121
+ """
122
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
123
+ f'But got {len(folders)}')
124
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
125
+ input_folder, gt_folder = folders
126
+ input_key, gt_key = keys
127
+
128
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
129
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
130
+ f'formats. But received {input_key}: {input_folder}; '
131
+ f'{gt_key}: {gt_folder}')
132
+ # ensure that the two meta_info files are the same
133
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
134
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
135
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
136
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
137
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
138
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
139
+ else:
140
+ paths = []
141
+ for lmdb_key in sorted(input_lmdb_keys):
142
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
143
+ return paths
144
+
145
+
146
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from a meta information file.
148
+
149
+ Each line in the meta information file contains the image names and
150
+ image shape (usually for gt), separated by a white space.
151
+
152
+ Example of a meta information file:
153
+ ```
154
+ 0001_s001.png (480,480,3)
155
+ 0001_s002.png (480,480,3)
156
+ ```
157
+
158
+ Args:
159
+ folders (list[str]): A list of folder path. The order of list should
160
+ be [input_folder, gt_folder].
161
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
163
+ meta_info_file (str): Path to the meta information file.
164
+ filename_tmpl (str): Template for each filename. Note that the
165
+ template excludes the file extension. Usually the filename_tmpl is
166
+ for files in the input folder.
167
+
168
+ Returns:
169
+ list[str]: Returned path list.
170
+ """
171
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
172
+ f'But got {len(folders)}')
173
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
174
+ input_folder, gt_folder = folders
175
+ input_key, gt_key = keys
176
+
177
+ with open(meta_info_file, 'r') as fin:
178
+ gt_names = [line.split(' ')[0] for line in fin]
179
+
180
+ paths = []
181
+ for gt_name in gt_names:
182
+ basename, ext = osp.splitext(osp.basename(gt_name))
183
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
184
+ input_path = osp.join(input_folder, input_name)
185
+ gt_path = osp.join(gt_folder, gt_name)
186
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
187
+ return paths
188
+
189
+
190
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
191
+ """Generate paired paths from folders.
192
+
193
+ Args:
194
+ folders (list[str]): A list of folder path. The order of list should
195
+ be [input_folder, gt_folder].
196
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
198
+ filename_tmpl (str): Template for each filename. Note that the
199
+ template excludes the file extension. Usually the filename_tmpl is
200
+ for files in the input folder.
201
+
202
+ Returns:
203
+ list[str]: Returned path list.
204
+ """
205
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
206
+ f'But got {len(folders)}')
207
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
208
+ input_folder, gt_folder = folders
209
+ input_key, gt_key = keys
210
+
211
+ input_paths = list(scandir(input_folder))
212
+ gt_paths = list(scandir(gt_folder))
213
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
214
+ f'{len(input_paths)}, {len(gt_paths)}.')
215
+ paths = []
216
+ for gt_path in gt_paths:
217
+ basename, ext = osp.splitext(osp.basename(gt_path))
218
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
219
+ input_path = osp.join(input_folder, input_name)
220
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
221
+ gt_path = osp.join(gt_folder, gt_path)
222
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
223
+ return paths
224
+
225
+
226
+ def paths_from_folder(folder):
227
+ """Generate paths from folder.
228
+
229
+ Args:
230
+ folder (str): Folder path.
231
+
232
+ Returns:
233
+ list[str]: Returned path list.
234
+ """
235
+
236
+ paths = list(scandir(folder))
237
+ paths = [osp.join(folder, path) for path in paths]
238
+ return paths
239
+
240
+
241
+ def paths_from_lmdb(folder):
242
+ """Generate paths from lmdb.
243
+
244
+ Args:
245
+ folder (str): Folder path.
246
+
247
+ Returns:
248
+ list[str]: Returned path list.
249
+ """
250
+ if not folder.endswith('.lmdb'):
251
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
252
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
253
+ paths = [line.split('.')[0] for line in fin]
254
+ return paths
255
+
256
+
257
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
258
+ """Generate Gaussian kernel used in `duf_downsample`.
259
+
260
+ Args:
261
+ kernel_size (int): Kernel size. Default: 13.
262
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
263
+
264
+ Returns:
265
+ np.array: The Gaussian kernel.
266
+ """
267
+ from scipy.ndimage import filters as filters
268
+ kernel = np.zeros((kernel_size, kernel_size))
269
+ # set element at the middle to one, a dirac delta
270
+ kernel[kernel_size // 2, kernel_size // 2] = 1
271
+ # gaussian-smooth the dirac, resulting in a gaussian filter
272
+ return filters.gaussian_filter(kernel, sigma)
273
+
274
+
275
+ def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsampling with Gaussian kernel used in the DUF official code.
277
+
278
+ Args:
279
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
280
+ kernel_size (int): Kernel size. Default: 13.
281
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
282
+ Default: 4.
283
+
284
+ Returns:
285
+ Tensor: DUF downsampled frames.
286
+ """
287
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
288
+
289
+ squeeze_flag = False
290
+ if x.ndim == 4:
291
+ squeeze_flag = True
292
+ x = x.unsqueeze(0)
293
+ b, t, c, h, w = x.size()
294
+ x = x.view(-1, 1, h, w)
295
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
296
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
297
+
298
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
299
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
300
+ x = F.conv2d(x, gaussian_filter, stride=scale)
301
+ x = x[:, :, 2:-2, 2:-2]
302
+ x = x.view(b, t, c, x.size(2), x.size(3))
303
+ if squeeze_flag:
304
+ x = x.squeeze(0)
305
+ return x
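
Two of the helpers above in action (illustrative, not part of the commit; duf_downsample additionally needs scipy, and the tensor shapes are assumptions):

    import torch
    from basicsr.data.data_util import generate_frame_indices, duf_downsample

    print(generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='reflection'))
    # -> [2, 1, 0, 1, 2], matching the docstring example

    frames = torch.rand(3, 3, 64, 64)               # (t, c, h, w); a batch dim is added internally
    lr_frames = duf_downsample(frames, kernel_size=13, scale=4)
    print(lr_frames.shape)                          # torch.Size([3, 3, 16, 16])
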
CodeFormer/basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Number of prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Number of prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
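
A minimal epoch loop with the CPU prefetcher (illustrative, not part of the commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from basicsr.data.prefetch_dataloader import CPUPrefetcher

    loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)
    prefetcher = CPUPrefetcher(loader)

    batch = prefetcher.next()
    while batch is not None:
        # ... run one training step on `batch` here ...
        batch = prefetcher.next()
    prefetcher.reset()                              # rewind the iterator for the next epoch
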
CodeFormer/basicsr/data/transforms.py ADDED
@@ -0,0 +1,165 @@
1
+ import cv2
2
+ import random
3
+
4
+
5
+ def mod_crop(img, scale):
6
+ """Mod crop images, used during testing.
7
+
8
+ Args:
9
+ img (ndarray): Input image.
10
+ scale (int): Scale factor.
11
+
12
+ Returns:
13
+ ndarray: Result image.
14
+ """
15
+ img = img.copy()
16
+ if img.ndim in (2, 3):
17
+ h, w = img.shape[0], img.shape[1]
18
+ h_remainder, w_remainder = h % scale, w % scale
19
+ img = img[:h - h_remainder, :w - w_remainder, ...]
20
+ else:
21
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
22
+ return img
23
+
24
+
25
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
26
+ """Paired random crop.
27
+
28
+ It crops lists of lq and gt images with corresponding locations.
29
+
30
+ Args:
31
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
32
+ should have the same shape. If the input is an ndarray, it will
33
+ be transformed to a list containing itself.
34
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
35
+ should have the same shape. If the input is an ndarray, it will
36
+ be transformed to a list containing itself.
37
+ gt_patch_size (int): GT patch size.
38
+ scale (int): Scale factor.
39
+ gt_path (str): Path to ground-truth.
40
+
41
+ Returns:
42
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
43
+ only have one element, just return ndarray.
44
+ """
45
+
46
+ if not isinstance(img_gts, list):
47
+ img_gts = [img_gts]
48
+ if not isinstance(img_lqs, list):
49
+ img_lqs = [img_lqs]
50
+
51
+ h_lq, w_lq, _ = img_lqs[0].shape
52
+ h_gt, w_gt, _ = img_gts[0].shape
53
+ lq_patch_size = gt_patch_size // scale
54
+
55
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
56
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
58
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
59
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
60
+ f'({lq_patch_size}, {lq_patch_size}). '
61
+ f'Please remove {gt_path}.')
62
+
63
+ # randomly choose top and left coordinates for lq patch
64
+ top = random.randint(0, h_lq - lq_patch_size)
65
+ left = random.randint(0, w_lq - lq_patch_size)
66
+
67
+ # crop lq patch
68
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
69
+
70
+ # crop corresponding gt patch
71
+ top_gt, left_gt = int(top * scale), int(left * scale)
72
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
73
+ if len(img_gts) == 1:
74
+ img_gts = img_gts[0]
75
+ if len(img_lqs) == 1:
76
+ img_lqs = img_lqs[0]
77
+ return img_gts, img_lqs
78
+
79
+
80
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
81
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
82
+
83
+ We use vertical flip and transpose for rotation implementation.
84
+ All the images in the list use the same augmentation.
85
+
86
+ Args:
87
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
88
+ is an ndarray, it will be transformed to a list.
89
+ hflip (bool): Horizontal flip. Default: True.
90
+ rotation (bool): Rotation. Default: True.
+ flows (list[ndarray]): Flows to be augmented. If the input is an
92
+ ndarray, it will be transformed to a list.
93
+ Dimension is (h, w, 2). Default: None.
94
+ return_status (bool): Return the status of flip and rotation.
95
+ Default: False.
96
+
97
+ Returns:
98
+ list[ndarray] | ndarray: Augmented images and flows. If returned
99
+ results only have one element, just return ndarray.
100
+
101
+ """
102
+ hflip = hflip and random.random() < 0.5
103
+ vflip = rotation and random.random() < 0.5
104
+ rot90 = rotation and random.random() < 0.5
105
+
106
+ def _augment(img):
107
+ if hflip: # horizontal
108
+ cv2.flip(img, 1, img)
109
+ if vflip: # vertical
110
+ cv2.flip(img, 0, img)
111
+ if rot90:
112
+ img = img.transpose(1, 0, 2)
113
+ return img
114
+
115
+ def _augment_flow(flow):
116
+ if hflip: # horizontal
117
+ cv2.flip(flow, 1, flow)
118
+ flow[:, :, 0] *= -1
119
+ if vflip: # vertical
120
+ cv2.flip(flow, 0, flow)
121
+ flow[:, :, 1] *= -1
122
+ if rot90:
123
+ flow = flow.transpose(1, 0, 2)
124
+ flow = flow[:, :, [1, 0]]
125
+ return flow
126
+
127
+ if not isinstance(imgs, list):
128
+ imgs = [imgs]
129
+ imgs = [_augment(img) for img in imgs]
130
+ if len(imgs) == 1:
131
+ imgs = imgs[0]
132
+
133
+ if flows is not None:
134
+ if not isinstance(flows, list):
135
+ flows = [flows]
136
+ flows = [_augment_flow(flow) for flow in flows]
137
+ if len(flows) == 1:
138
+ flows = flows[0]
139
+ return imgs, flows
140
+ else:
141
+ if return_status:
142
+ return imgs, (hflip, vflip, rot90)
143
+ else:
144
+ return imgs
145
+
146
+
147
+ def img_rotate(img, angle, center=None, scale=1.0):
148
+ """Rotate image.
149
+
150
+ Args:
151
+ img (ndarray): Image to be rotated.
152
+ angle (float): Rotation angle in degrees. Positive values mean
153
+ counter-clockwise rotation.
154
+ center (tuple[int]): Rotation center. If the center is None,
155
+ initialize it as the center of the image. Default: None.
156
+ scale (float): Isotropic scale factor. Default: 1.0.
157
+ """
158
+ (h, w) = img.shape[:2]
159
+
160
+ if center is None:
161
+ center = (w // 2, h // 2)
162
+
163
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
164
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
165
+ return rotated_img
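
How the cropping and augmentation helpers are usually chained (illustrative, not part of the commit; the image sizes are assumptions):

    import numpy as np
    from basicsr.data.transforms import augment, paired_random_crop

    img_gt = np.random.rand(256, 256, 3).astype(np.float32)
    img_lq = np.random.rand(64, 64, 3).astype(np.float32)      # 4x smaller than the GT image

    gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size=128, scale=4, gt_path='demo.png')
    gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
    print(gt_patch.shape, lq_patch.shape)                       # (128, 128, 3) (32, 32, 3)
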
CodeFormer/basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
+ def build_loss(opt):
15
+ """Build loss from options.
16
+
17
+ Args:
18
+ opt (dict): Configuration. It must contain:
+ type (str): Loss type.
20
+ """
21
+ opt = deepcopy(opt)
22
+ loss_type = opt.pop('type')
23
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
+ logger = get_root_logger()
25
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
+ return loss
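
A quick sketch of building a registered loss from an options dict (illustrative, not part of the commit; importing basicsr.losses assumes the repo's requirements such as lpips are installed):

    import torch
    from basicsr.losses import build_loss

    cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
    pred, target = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
    print(cri_pix(pred, target))                    # scalar L1 loss scaled by loss_weight
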
CodeFormer/basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
CodeFormer/basicsr/losses/losses.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ import lpips
3
+ import torch
4
+ from torch import autograd as autograd
5
+ from torch import nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
9
+ from basicsr.utils.registry import LOSS_REGISTRY
10
+ from .loss_util import weighted_loss
11
+
12
+ _reduction_modes = ['none', 'mean', 'sum']
13
+
14
+
15
+ @weighted_loss
16
+ def l1_loss(pred, target):
17
+ return F.l1_loss(pred, target, reduction='none')
18
+
19
+
20
+ @weighted_loss
21
+ def mse_loss(pred, target):
22
+ return F.mse_loss(pred, target, reduction='none')
23
+
24
+
25
+ @weighted_loss
26
+ def charbonnier_loss(pred, target, eps=1e-12):
27
+ return torch.sqrt((pred - target)**2 + eps)
28
+
29
+
30
+ @LOSS_REGISTRY.register()
31
+ class L1Loss(nn.Module):
32
+ """L1 (mean absolute error, MAE) loss.
33
+
34
+ Args:
35
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
36
+ reduction (str): Specifies the reduction to apply to the output.
37
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
38
+ """
39
+
40
+ def __init__(self, loss_weight=1.0, reduction='mean'):
41
+ super(L1Loss, self).__init__()
42
+ if reduction not in ['none', 'mean', 'sum']:
43
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
44
+
45
+ self.loss_weight = loss_weight
46
+ self.reduction = reduction
47
+
48
+ def forward(self, pred, target, weight=None, **kwargs):
49
+ """
50
+ Args:
51
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
52
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
53
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
54
+ weights. Default: None.
55
+ """
56
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
57
+
58
+
59
+ @LOSS_REGISTRY.register()
60
+ class MSELoss(nn.Module):
61
+ """MSE (L2) loss.
62
+
63
+ Args:
64
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
65
+ reduction (str): Specifies the reduction to apply to the output.
66
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
67
+ """
68
+
69
+ def __init__(self, loss_weight=1.0, reduction='mean'):
70
+ super(MSELoss, self).__init__()
71
+ if reduction not in ['none', 'mean', 'sum']:
72
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
73
+
74
+ self.loss_weight = loss_weight
75
+ self.reduction = reduction
76
+
77
+ def forward(self, pred, target, weight=None, **kwargs):
78
+ """
79
+ Args:
80
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
81
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
82
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
83
+ weights. Default: None.
84
+ """
85
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
86
+
87
+
88
+ @LOSS_REGISTRY.register()
89
+ class CharbonnierLoss(nn.Module):
90
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
91
+ variant of L1Loss).
92
+
93
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
94
+ Super-Resolution".
95
+
96
+ Args:
97
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
98
+ reduction (str): Specifies the reduction to apply to the output.
99
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
100
+ eps (float): A value used to control the curvature near zero.
101
+ Default: 1e-12.
102
+ """
103
+
104
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
105
+ super(CharbonnierLoss, self).__init__()
106
+ if reduction not in ['none', 'mean', 'sum']:
107
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
108
+
109
+ self.loss_weight = loss_weight
110
+ self.reduction = reduction
111
+ self.eps = eps
112
+
113
+ def forward(self, pred, target, weight=None, **kwargs):
114
+ """
115
+ Args:
116
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
117
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
118
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
119
+ weights. Default: None.
120
+ """
121
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
122
+
123
+
124
+ @LOSS_REGISTRY.register()
125
+ class WeightedTVLoss(L1Loss):
126
+ """Weighted TV loss.
127
+
128
+ Args:
129
+ loss_weight (float): Loss weight. Default: 1.0.
130
+ """
131
+
132
+ def __init__(self, loss_weight=1.0):
133
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
134
+
135
+ def forward(self, pred, weight=None):
+ # handle the default weight=None, which would otherwise fail on slicing
+ y_weight = None if weight is None else weight[:, :, :-1, :]
+ x_weight = None if weight is None else weight[:, :, :, :-1]
+ y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+ x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
138
+
139
+ loss = x_diff + y_diff
140
+
141
+ return loss
142
+
143
+
144
+ @LOSS_REGISTRY.register()
145
+ class PerceptualLoss(nn.Module):
146
+ """Perceptual loss with commonly used style loss.
147
+
148
+ Args:
149
+ layer_weights (dict): The weight for each layer of vgg feature.
150
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
151
+ feature layer (before relu5_4) will be extracted with weight
152
+ 1.0 in calculating losses.
153
+ vgg_type (str): The type of vgg network used as feature extractor.
154
+ Default: 'vgg19'.
155
+ use_input_norm (bool): If True, normalize the input image in vgg.
156
+ Default: True.
157
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
158
+ Default: False.
159
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
160
+ loss will be calculated and the loss will be multiplied by the
+ weight. Default: 1.0.
+ style_weight (float): If `style_weight > 0`, the style loss will be
+ calculated and the loss will be multiplied by the weight.
164
+ Default: 0.
165
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
166
+ """
167
+
168
+ def __init__(self,
169
+ layer_weights,
170
+ vgg_type='vgg19',
171
+ use_input_norm=True,
172
+ range_norm=False,
173
+ perceptual_weight=1.0,
174
+ style_weight=0.,
175
+ criterion='l1'):
176
+ super(PerceptualLoss, self).__init__()
177
+ self.perceptual_weight = perceptual_weight
178
+ self.style_weight = style_weight
179
+ self.layer_weights = layer_weights
180
+ self.vgg = VGGFeatureExtractor(
181
+ layer_name_list=list(layer_weights.keys()),
182
+ vgg_type=vgg_type,
183
+ use_input_norm=use_input_norm,
184
+ range_norm=range_norm)
185
+
186
+ self.criterion_type = criterion
187
+ if self.criterion_type == 'l1':
188
+ self.criterion = torch.nn.L1Loss()
189
+ elif self.criterion_type == 'l2':
190
+ self.criterion = torch.nn.MSELoss()  # PyTorch has no L2Loss; MSELoss is the standard l2 criterion
191
+ elif self.criterion_type == 'mse':
192
+ self.criterion = torch.nn.MSELoss(reduction='mean')
193
+ elif self.criterion_type == 'fro':
194
+ self.criterion = None
195
+ else:
196
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
197
+
198
+ def forward(self, x, gt):
199
+ """Forward function.
200
+
201
+ Args:
202
+ x (Tensor): Input tensor with shape (n, c, h, w).
203
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
204
+
205
+ Returns:
206
+ Tensor: Forward results.
207
+ """
208
+ # extract vgg features
209
+ x_features = self.vgg(x)
210
+ gt_features = self.vgg(gt.detach())
211
+
212
+ # calculate perceptual loss
213
+ if self.perceptual_weight > 0:
214
+ percep_loss = 0
215
+ for k in x_features.keys():
216
+ if self.criterion_type == 'fro':
217
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
218
+ else:
219
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
220
+ percep_loss *= self.perceptual_weight
221
+ else:
222
+ percep_loss = None
223
+
224
+ # calculate style loss
225
+ if self.style_weight > 0:
226
+ style_loss = 0
227
+ for k in x_features.keys():
228
+ if self.criterion_type == 'fro':
229
+ style_loss += torch.norm(
230
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
231
+ else:
232
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
233
+ gt_features[k])) * self.layer_weights[k]
234
+ style_loss *= self.style_weight
235
+ else:
236
+ style_loss = None
237
+
238
+ return percep_loss, style_loss
239
+
240
+ def _gram_mat(self, x):
241
+ """Calculate Gram matrix.
242
+
243
+ Args:
244
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
245
+
246
+ Returns:
247
+ torch.Tensor: Gram matrix.
248
+ """
249
+ n, c, h, w = x.size()
250
+ features = x.view(n, c, w * h)
251
+ features_t = features.transpose(1, 2)
252
+ gram = features.bmm(features_t) / (c * h * w)
253
+ return gram
254
+
255
+
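A hedged usage sketch for the class above; the layer name and weights are only illustrative, and VGGFeatureExtractor fetches pretrained VGG19 weights on first use:

    import torch
    from basicsr.losses.losses import PerceptualLoss  # module path as in this diff

    percep = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
                            perceptual_weight=1.0, style_weight=0.)
    sr = torch.rand(1, 3, 256, 256)  # hypothetical restored output in [0, 1]
    gt = torch.rand(1, 3, 256, 256)  # hypothetical ground truth in [0, 1]
    l_percep, l_style = percep(sr, gt)  # l_style is None because style_weight=0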
256
+ @LOSS_REGISTRY.register()
257
+ class LPIPSLoss(nn.Module):
258
+ def __init__(self,
259
+ loss_weight=1.0,
260
+ use_input_norm=True,
261
+ range_norm=False,):
262
+ super(LPIPSLoss, self).__init__()
263
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
264
+ self.loss_weight = loss_weight
265
+ self.use_input_norm = use_input_norm
266
+ self.range_norm = range_norm
267
+
268
+ if self.use_input_norm:
269
+ # the mean is for image with range [0, 1]
270
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
271
+ # the std is for image with range [0, 1]
272
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
273
+
274
+ def forward(self, pred, target):
275
+ if self.range_norm:
276
+ pred = (pred + 1) / 2
277
+ target = (target + 1) / 2
278
+ if self.use_input_norm:
279
+ pred = (pred - self.mean) / self.std
280
+ target = (target - self.mean) / self.std
281
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
282
+ return self.loss_weight * lpips_loss.mean()
283
+
284
+
285
+ @LOSS_REGISTRY.register()
286
+ class GANLoss(nn.Module):
287
+ """Define GAN loss.
288
+
289
+ Args:
290
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
291
+ real_label_val (float): The value for real label. Default: 1.0.
292
+ fake_label_val (float): The value for fake label. Default: 0.0.
293
+ loss_weight (float): Loss weight. Default: 1.0.
294
+ Note that loss_weight is only applied to generators; it is always 1.0
295
+ for discriminators.
296
+ """
297
+
298
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
299
+ super(GANLoss, self).__init__()
300
+ self.gan_type = gan_type
301
+ self.loss_weight = loss_weight
302
+ self.real_label_val = real_label_val
303
+ self.fake_label_val = fake_label_val
304
+
305
+ if self.gan_type == 'vanilla':
306
+ self.loss = nn.BCEWithLogitsLoss()
307
+ elif self.gan_type == 'lsgan':
308
+ self.loss = nn.MSELoss()
309
+ elif self.gan_type == 'wgan':
310
+ self.loss = self._wgan_loss
311
+ elif self.gan_type == 'wgan_softplus':
312
+ self.loss = self._wgan_softplus_loss
313
+ elif self.gan_type == 'hinge':
314
+ self.loss = nn.ReLU()
315
+ else:
316
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
317
+
318
+ def _wgan_loss(self, input, target):
319
+ """wgan loss.
320
+
321
+ Args:
322
+ input (Tensor): Input tensor.
323
+ target (bool): Target label.
324
+
325
+ Returns:
326
+ Tensor: wgan loss.
327
+ """
328
+ return -input.mean() if target else input.mean()
329
+
330
+ def _wgan_softplus_loss(self, input, target):
331
+ """wgan loss with soft plus. softplus is a smooth approximation to the
332
+ ReLU function.
333
+
334
+ In StyleGAN2, it is called:
335
+ Logistic loss for discriminator;
336
+ Non-saturating loss for generator.
337
+
338
+ Args:
339
+ input (Tensor): Input tensor.
340
+ target (bool): Target label.
341
+
342
+ Returns:
343
+ Tensor: wgan loss.
344
+ """
345
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
346
+
347
+ def get_target_label(self, input, target_is_real):
348
+ """Get target label.
349
+
350
+ Args:
351
+ input (Tensor): Input tensor.
352
+ target_is_real (bool): Whether the target is real or fake.
353
+
354
+ Returns:
355
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
356
+ return Tensor.
357
+ """
358
+
359
+ if self.gan_type in ['wgan', 'wgan_softplus']:
360
+ return target_is_real
361
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
362
+ return input.new_ones(input.size()) * target_val
363
+
364
+ def forward(self, input, target_is_real, is_disc=False):
365
+ """
366
+ Args:
367
+ input (Tensor): The input for the loss module, i.e., the network
368
+ prediction.
369
+ target_is_real (bool): Whether the target is real or fake.
370
+ is_disc (bool): Whether the loss is computed for discriminators or not.
371
+ Default: False.
372
+
373
+ Returns:
374
+ Tensor: GAN loss value.
375
+ """
376
+ if self.gan_type == 'hinge':
377
+ if is_disc: # for discriminators in hinge-gan
378
+ input = -input if target_is_real else input
379
+ loss = self.loss(1 + input).mean()
380
+ else: # for generators in hinge-gan
381
+ loss = -input.mean()
382
+ else: # other gan types
383
+ target_label = self.get_target_label(input, target_is_real)
384
+ loss = self.loss(input, target_label)
385
+
386
+ # loss_weight is always 1.0 for discriminators
387
+ return loss if is_disc else loss * self.loss_weight
388
+
389
+
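To make the `is_disc` switch concrete, a minimal sketch of how this loss is typically wired into a GAN training step (the logits are random placeholders):

    import torch
    from basicsr.losses.losses import GANLoss  # module path as in this diff

    cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)
    real_logits = torch.randn(4, 1)  # D(real), hypothetical
    fake_logits = torch.randn(4, 1)  # D(fake), hypothetical

    # generator step: try to fool D; loss_weight is applied
    l_g_gan = cri_gan(fake_logits, target_is_real=True, is_disc=False)

    # discriminator step: loss_weight is ignored (always 1.0)
    l_d_real = cri_gan(real_logits, target_is_real=True, is_disc=True)
    l_d_fake = cri_gan(fake_logits.detach(), target_is_real=False, is_disc=True)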
390
+ def r1_penalty(real_pred, real_img):
391
+ """R1 regularization for discriminator. The core idea is to
392
+ penalize the gradient on real data alone: when the
393
+ generator distribution produces the true data distribution
394
+ and the discriminator is equal to 0 on the data manifold, the
395
+ gradient penalty ensures that the discriminator cannot create
396
+ a non-zero gradient orthogonal to the data manifold without
397
+ suffering a loss in the GAN game.
398
+
399
+ Ref:
400
+ Eq. 9 in Which training methods for GANs do actually converge.
401
+ """
402
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
403
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
404
+ return grad_penalty
405
+
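Note that `r1_penalty` differentiates the discriminator output with respect to `real_img`, so the images must require gradients before the forward pass. A sketch with a placeholder discriminator:

    import torch
    from torch import nn
    from basicsr.losses.losses import r1_penalty  # module path as in this diff

    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))  # placeholder discriminator
    real_img = torch.rand(2, 3, 64, 64, requires_grad=True)  # grads w.r.t. the input are needed
    real_pred = disc(real_img)
    penalty = r1_penalty(real_pred, real_img)
    (10. * penalty).backward()  # the weight 10 is illustrative, not part of this code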
406
+
407
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
408
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
409
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
410
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
411
+
412
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
413
+
414
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
415
+
416
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
417
+
418
+
419
+ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
420
+ """Calculate gradient penalty for wgan-gp.
421
+
422
+ Args:
423
+ discriminator (nn.Module): Network for the discriminator.
424
+ real_data (Tensor): Real input data.
425
+ fake_data (Tensor): Fake input data.
426
+ weight (Tensor): Weight tensor. Default: None.
427
+
428
+ Returns:
429
+ Tensor: A tensor for gradient penalty.
430
+ """
431
+
432
+ batch_size = real_data.size(0)
433
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
434
+
435
+ # interpolate between real_data and fake_data
436
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
437
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
438
+
439
+ disc_interpolates = discriminator(interpolates)
440
+ gradients = autograd.grad(
441
+ outputs=disc_interpolates,
442
+ inputs=interpolates,
443
+ grad_outputs=torch.ones_like(disc_interpolates),
444
+ create_graph=True,
445
+ retain_graph=True,
446
+ only_inputs=True)[0]
447
+
448
+ if weight is not None:
449
+ gradients = gradients * weight
450
+
451
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
452
+ if weight is not None:
453
+ gradients_penalty /= torch.mean(weight)
454
+
455
+ return gradients_penalty
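The WGAN-GP penalty above is typically added to the discriminator loss with a fixed weight; a sketch with a placeholder discriminator and random data (the weight 10 follows the original WGAN-GP paper, not this file):

    import torch
    from torch import nn
    from basicsr.losses.losses import gradient_penalty_loss  # module path as in this diff

    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))  # placeholder discriminator
    real = torch.rand(4, 3, 64, 64)
    fake = torch.rand(4, 3, 64, 64)
    gp = gradient_penalty_loss(disc, real, fake)
    d_loss = -disc(real).mean() + disc(fake).mean() + 10. * gp  # WGAN critic loss + penalty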
CodeFormer/basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+
6
+ __all__ = ['calculate_psnr', 'calculate_ssim']
7
+
8
+
9
+ def calculate_metric(data, opt):
10
+ """Calculate metric from data and options.
11
+
12
+ Args:
13
+ opt (dict): Configuration. It must contain:
14
+ type (str): Metric type.
15
+ """
16
+ opt = deepcopy(opt)
17
+ metric_type = opt.pop('type')
18
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
19
+ return metric
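A hedged example of the registry dispatch above, using the PSNR metric registered in psnr_ssim.py (the arrays are random stand-ins for aligned images in [0, 255]):

    import numpy as np
    from basicsr.metrics import calculate_metric  # module path as in this diff

    img1 = (np.random.rand(64, 64, 3) * 255).astype(np.float64)
    img2 = (np.random.rand(64, 64, 3) * 255).astype(np.float64)
    data = dict(img1=img1, img2=img2)
    opt = dict(type='calculate_psnr', crop_border=4, test_y_channel=False)
    psnr = calculate_metric(data, opt)  # 'type' is popped and looked up in METRIC_REGISTRY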
CodeFormer/basicsr/metrics/metric_util.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ from basicsr.utils.matlab_functions import bgr2ycbcr
4
+
5
+
6
+ def reorder_image(img, input_order='HWC'):
7
+ """Reorder images to 'HWC' order.
8
+
9
+ If the input_order is (h, w), return (h, w, 1);
10
+ If the input_order is (c, h, w), return (h, w, c);
11
+ If the input_order is (h, w, c), return as it is.
12
+
13
+ Args:
14
+ img (ndarray): Input image.
15
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
16
+ If the input image shape is (h, w), input_order will not have
17
+ effects. Default: 'HWC'.
18
+
19
+ Returns:
20
+ ndarray: reordered image.
21
+ """
22
+
23
+ if input_order not in ['HWC', 'CHW']:
24
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
25
+ if len(img.shape) == 2:
26
+ img = img[..., None]
27
+ if input_order == 'CHW':
28
+ img = img.transpose(1, 2, 0)
29
+ return img
30
+
31
+
32
+ def to_y_channel(img):
33
+ """Change to Y channel of YCbCr.
34
+
35
+ Args:
36
+ img (ndarray): Images with range [0, 255].
37
+
38
+ Returns:
39
+ (ndarray): Images with range [0, 255] (float type) without round.
40
+ """
41
+ img = img.astype(np.float32) / 255.
42
+ if img.ndim == 3 and img.shape[2] == 3:
43
+ img = bgr2ycbcr(img, y_only=True)
44
+ img = img[..., None]
45
+ return img * 255.
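A small sketch of the two helpers above, using a random array as a stand-in for a BGR image:

    import numpy as np
    from basicsr.metrics.metric_util import reorder_image, to_y_channel  # path as in this diff

    chw = np.random.rand(3, 32, 32) * 255  # channel-first input in [0, 255]
    hwc = reorder_image(chw, input_order='CHW')  # -> shape (32, 32, 3)
    y = to_y_channel(hwc)  # -> shape (32, 32, 1), luma channel in [0, 255]
    assert hwc.shape == (32, 32, 3) and y.shape == (32, 32, 1)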
CodeFormer/basicsr/metrics/psnr_ssim.py ADDED
@@ -0,0 +1,128 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from basicsr.metrics.metric_util import reorder_image, to_y_channel
5
+ from basicsr.utils.registry import METRIC_REGISTRY
6
+
7
+
8
+ @METRIC_REGISTRY.register()
9
+ def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
10
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
11
+
12
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13
+
14
+ Args:
15
+ img1 (ndarray): Images with range [0, 255].
16
+ img2 (ndarray): Images with range [0, 255].
17
+ crop_border (int): Cropped pixels in each edge of an image. These
18
+ pixels are not involved in the PSNR calculation.
19
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
20
+ Default: 'HWC'.
21
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22
+
23
+ Returns:
24
+ float: psnr result.
25
+ """
26
+
27
+ assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
28
+ if input_order not in ['HWC', 'CHW']:
29
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
30
+ img1 = reorder_image(img1, input_order=input_order)
31
+ img2 = reorder_image(img2, input_order=input_order)
32
+ img1 = img1.astype(np.float64)
33
+ img2 = img2.astype(np.float64)
34
+
35
+ if crop_border != 0:
36
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38
+
39
+ if test_y_channel:
40
+ img1 = to_y_channel(img1)
41
+ img2 = to_y_channel(img2)
42
+
43
+ mse = np.mean((img1 - img2)**2)
44
+ if mse == 0:
45
+ return float('inf')
46
+ return 20. * np.log10(255. / np.sqrt(mse))
47
+
48
+
49
+ def _ssim(img1, img2):
50
+ """Calculate SSIM (structural similarity) for one channel images.
51
+
52
+ It is called by func:`calculate_ssim`.
53
+
54
+ Args:
55
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
56
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57
+
58
+ Returns:
59
+ float: ssim result.
60
+ """
61
+
62
+ C1 = (0.01 * 255)**2
63
+ C2 = (0.03 * 255)**2
64
+
65
+ img1 = img1.astype(np.float64)
66
+ img2 = img2.astype(np.float64)
67
+ kernel = cv2.getGaussianKernel(11, 1.5)
68
+ window = np.outer(kernel, kernel.transpose())
69
+
70
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
71
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72
+ mu1_sq = mu1**2
73
+ mu2_sq = mu2**2
74
+ mu1_mu2 = mu1 * mu2
75
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
76
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
77
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78
+
79
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
80
+ return ssim_map.mean()
81
+
82
+
83
+ @METRIC_REGISTRY.register()
84
+ def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
85
+ """Calculate SSIM (structural similarity).
86
+
87
+ Ref:
88
+ Image quality assessment: From error visibility to structural similarity
89
+
90
+ The results are the same as that of the official released MATLAB code in
91
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
+
93
+ For three-channel images, SSIM is calculated for each channel and then
94
+ averaged.
95
+
96
+ Args:
97
+ img1 (ndarray): Images with range [0, 255].
98
+ img2 (ndarray): Images with range [0, 255].
99
+ crop_border (int): Cropped pixels in each edge of an image. These
100
+ pixels are not involved in the SSIM calculation.
101
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
+ Default: 'HWC'.
103
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
+
105
+ Returns:
106
+ float: ssim result.
107
+ """
108
+
109
+ assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
110
+ if input_order not in ['HWC', 'CHW']:
111
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
112
+ img1 = reorder_image(img1, input_order=input_order)
113
+ img2 = reorder_image(img2, input_order=input_order)
114
+ img1 = img1.astype(np.float64)
115
+ img2 = img2.astype(np.float64)
116
+
117
+ if crop_border != 0:
118
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
119
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
+
121
+ if test_y_channel:
122
+ img1 = to_y_channel(img1)
123
+ img2 = to_y_channel(img2)
124
+
125
+ ssims = []
126
+ for i in range(img1.shape[2]):
127
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
128
+ return np.array(ssims).mean()
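For a quick sanity check of the two metrics above: identical inputs should give infinite PSNR and SSIM of 1, and a mildly perturbed copy a finite PSNR (the noise level is arbitrary):

    import numpy as np
    from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim  # path as in this diff

    img = (np.random.rand(128, 128, 3) * 255).astype(np.float64)
    print(calculate_psnr(img, img, crop_border=0))   # inf, since the MSE is zero
    print(calculate_ssim(img, img, crop_border=0))   # 1.0 up to float precision
    noisy = np.clip(img + np.random.randn(*img.shape) * 5, 0, 255)
    print(calculate_psnr(img, noisy, crop_border=4))  # finite value in dB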
CodeFormer/basicsr/models/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import MODEL_REGISTRY
7
+
8
+ __all__ = ['build_model']
9
+
10
+ # automatically scan and import model modules for registry
11
+ # scan all the files under the 'models' folder and collect files ending with
12
+ # '_model.py'
13
+ model_folder = osp.dirname(osp.abspath(__file__))
14
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15
+ # import all the model modules
16
+ _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
17
+
18
+
19
+ def build_model(opt):
20
+ """Build model from options.
21
+
22
+ Args:
23
+ opt (dict): Configuration. It must contain:
24
+ model_type (str): Model type.
25
+ """
26
+ opt = deepcopy(opt)
27
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28
+ logger = get_root_logger()
29
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
30
+ return model
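The lookup above expects opt['model_type'] to name a class registered with @MODEL_REGISTRY.register() in some *_model.py file; a hedged sketch with a toy class that is not part of this commit:

    from basicsr.models import build_model
    from basicsr.utils.registry import MODEL_REGISTRY

    @MODEL_REGISTRY.register()
    class ToyModel:  # hypothetical stand-in, only to illustrate the registry
        def __init__(self, opt):
            self.opt = opt

    model = build_model({'model_type': 'ToyModel'})
    print(type(model).__name__)  # 'ToyModel'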
CodeFormer/basicsr/ops/__init__.py ADDED
File without changes
CodeFormer/basicsr/ops/dcn/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
2
+ modulated_deform_conv)
3
+
4
+ __all__ = [
5
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6
+ 'modulated_deform_conv'
7
+ ]
CodeFormer/basicsr/ops/dcn/deform_conv.py ADDED
@@ -0,0 +1,377 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.autograd import Function
5
+ from torch.autograd.function import once_differentiable
6
+ from torch.nn import functional as F
7
+ from torch.nn.modules.utils import _pair, _single
8
+
9
+ try:
10
+ from . import deform_conv_ext
11
+ except ImportError:
12
+ import os
13
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
14
+ if BASICSR_JIT == 'True':
15
+ from torch.utils.cpp_extension import load
16
+ module_path = os.path.dirname(__file__)
17
+ deform_conv_ext = load(
18
+ 'deform_conv',
19
+ sources=[
20
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
21
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
22
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
23
+ ],
24
+ )
25
+
26
+
27
+ class DeformConvFunction(Function):
28
+
29
+ @staticmethod
30
+ def forward(ctx,
31
+ input,
32
+ offset,
33
+ weight,
34
+ stride=1,
35
+ padding=0,
36
+ dilation=1,
37
+ groups=1,
38
+ deformable_groups=1,
39
+ im2col_step=64):
40
+ if input is not None and input.dim() != 4:
41
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
42
+ ctx.stride = _pair(stride)
43
+ ctx.padding = _pair(padding)
44
+ ctx.dilation = _pair(dilation)
45
+ ctx.groups = groups
46
+ ctx.deformable_groups = deformable_groups
47
+ ctx.im2col_step = im2col_step
48
+
49
+ ctx.save_for_backward(input, offset, weight)
50
+
51
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
52
+
53
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
54
+
55
+ if not input.is_cuda:
56
+ raise NotImplementedError
57
+ else:
58
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
59
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
60
+ deform_conv_ext.deform_conv_forward(input, weight,
61
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
62
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
63
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
64
+ ctx.deformable_groups, cur_im2col_step)
65
+ return output
66
+
67
+ @staticmethod
68
+ @once_differentiable
69
+ def backward(ctx, grad_output):
70
+ input, offset, weight = ctx.saved_tensors
71
+
72
+ grad_input = grad_offset = grad_weight = None
73
+
74
+ if not grad_output.is_cuda:
75
+ raise NotImplementedError
76
+ else:
77
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
78
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
79
+
80
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
81
+ grad_input = torch.zeros_like(input)
82
+ grad_offset = torch.zeros_like(offset)
83
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
84
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
85
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
86
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
87
+ ctx.deformable_groups, cur_im2col_step)
88
+
89
+ if ctx.needs_input_grad[2]:
90
+ grad_weight = torch.zeros_like(weight)
91
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
92
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
93
+ weight.size(2), ctx.stride[1], ctx.stride[0],
94
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
95
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
96
+ cur_im2col_step)
97
+
98
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
99
+
100
+ @staticmethod
101
+ def _output_size(input, weight, padding, dilation, stride):
102
+ channels = weight.size(0)
103
+ output_size = (input.size(0), channels)
104
+ for d in range(input.dim() - 2):
105
+ in_size = input.size(d + 2)
106
+ pad = padding[d]
107
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
108
+ stride_ = stride[d]
109
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
110
+ if not all(map(lambda s: s > 0, output_size)):
111
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
112
+ return output_size
113
+
114
+
115
+ class ModulatedDeformConvFunction(Function):
116
+
117
+ @staticmethod
118
+ def forward(ctx,
119
+ input,
120
+ offset,
121
+ mask,
122
+ weight,
123
+ bias=None,
124
+ stride=1,
125
+ padding=0,
126
+ dilation=1,
127
+ groups=1,
128
+ deformable_groups=1):
129
+ ctx.stride = stride
130
+ ctx.padding = padding
131
+ ctx.dilation = dilation
132
+ ctx.groups = groups
133
+ ctx.deformable_groups = deformable_groups
134
+ ctx.with_bias = bias is not None
135
+ if not ctx.with_bias:
136
+ bias = input.new_empty(1) # fake tensor
137
+ if not input.is_cuda:
138
+ raise NotImplementedError
139
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
140
+ or input.requires_grad:
141
+ ctx.save_for_backward(input, offset, mask, weight, bias)
142
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
143
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
144
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
145
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
146
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
147
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
148
+ return output
149
+
150
+ @staticmethod
151
+ @once_differentiable
152
+ def backward(ctx, grad_output):
153
+ if not grad_output.is_cuda:
154
+ raise NotImplementedError
155
+ input, offset, mask, weight, bias = ctx.saved_tensors
156
+ grad_input = torch.zeros_like(input)
157
+ grad_offset = torch.zeros_like(offset)
158
+ grad_mask = torch.zeros_like(mask)
159
+ grad_weight = torch.zeros_like(weight)
160
+ grad_bias = torch.zeros_like(bias)
161
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
162
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
163
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
164
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
165
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
166
+ if not ctx.with_bias:
167
+ grad_bias = None
168
+
169
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
170
+
171
+ @staticmethod
172
+ def _infer_shape(ctx, input, weight):
173
+ n = input.size(0)
174
+ channels_out = weight.size(0)
175
+ height, width = input.shape[2:4]
176
+ kernel_h, kernel_w = weight.shape[2:4]
177
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
178
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
179
+ return n, channels_out, height_out, width_out
180
+
181
+
182
+ deform_conv = DeformConvFunction.apply
183
+ modulated_deform_conv = ModulatedDeformConvFunction.apply
184
+
185
+
186
+ class DeformConv(nn.Module):
187
+
188
+ def __init__(self,
189
+ in_channels,
190
+ out_channels,
191
+ kernel_size,
192
+ stride=1,
193
+ padding=0,
194
+ dilation=1,
195
+ groups=1,
196
+ deformable_groups=1,
197
+ bias=False):
198
+ super(DeformConv, self).__init__()
199
+
200
+ assert not bias
201
+ assert in_channels % groups == 0, \
202
+ f'in_channels {in_channels} is not divisible by groups {groups}'
203
+ assert out_channels % groups == 0, \
204
+ f'out_channels {out_channels} is not divisible ' \
205
+ f'by groups {groups}'
206
+
207
+ self.in_channels = in_channels
208
+ self.out_channels = out_channels
209
+ self.kernel_size = _pair(kernel_size)
210
+ self.stride = _pair(stride)
211
+ self.padding = _pair(padding)
212
+ self.dilation = _pair(dilation)
213
+ self.groups = groups
214
+ self.deformable_groups = deformable_groups
215
+ # enable compatibility with nn.Conv2d
216
+ self.transposed = False
217
+ self.output_padding = _single(0)
218
+
219
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
220
+
221
+ self.reset_parameters()
222
+
223
+ def reset_parameters(self):
224
+ n = self.in_channels
225
+ for k in self.kernel_size:
226
+ n *= k
227
+ stdv = 1. / math.sqrt(n)
228
+ self.weight.data.uniform_(-stdv, stdv)
229
+
230
+ def forward(self, x, offset):
231
+ # To fix an assert error in deform_conv_cuda.cpp:128
232
+ # input image is smaller than kernel
233
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
234
+ if input_pad:
235
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
236
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
237
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
238
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
239
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
240
+ self.deformable_groups)
241
+ if input_pad:
242
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
243
+ return out
244
+
245
+
246
+ class DeformConvPack(DeformConv):
247
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
248
+
249
+ Args:
250
+ in_channels (int): Same as nn.Conv2d.
251
+ out_channels (int): Same as nn.Conv2d.
252
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
253
+ stride (int or tuple[int]): Same as nn.Conv2d.
254
+ padding (int or tuple[int]): Same as nn.Conv2d.
255
+ dilation (int or tuple[int]): Same as nn.Conv2d.
256
+ groups (int): Same as nn.Conv2d.
257
+ bias (bool or str): If specified as `auto`, it will be decided by the
258
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
259
+ False.
260
+ """
261
+
262
+ _version = 2
263
+
264
+ def __init__(self, *args, **kwargs):
265
+ super(DeformConvPack, self).__init__(*args, **kwargs)
266
+
267
+ self.conv_offset = nn.Conv2d(
268
+ self.in_channels,
269
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
270
+ kernel_size=self.kernel_size,
271
+ stride=_pair(self.stride),
272
+ padding=_pair(self.padding),
273
+ dilation=_pair(self.dilation),
274
+ bias=True)
275
+ self.init_offset()
276
+
277
+ def init_offset(self):
278
+ self.conv_offset.weight.data.zero_()
279
+ self.conv_offset.bias.data.zero_()
280
+
281
+ def forward(self, x):
282
+ offset = self.conv_offset(x)
283
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
284
+ self.deformable_groups)
285
+
286
+
287
+ class ModulatedDeformConv(nn.Module):
288
+
289
+ def __init__(self,
290
+ in_channels,
291
+ out_channels,
292
+ kernel_size,
293
+ stride=1,
294
+ padding=0,
295
+ dilation=1,
296
+ groups=1,
297
+ deformable_groups=1,
298
+ bias=True):
299
+ super(ModulatedDeformConv, self).__init__()
300
+ self.in_channels = in_channels
301
+ self.out_channels = out_channels
302
+ self.kernel_size = _pair(kernel_size)
303
+ self.stride = stride
304
+ self.padding = padding
305
+ self.dilation = dilation
306
+ self.groups = groups
307
+ self.deformable_groups = deformable_groups
308
+ self.with_bias = bias
309
+ # enable compatibility with nn.Conv2d
310
+ self.transposed = False
311
+ self.output_padding = _single(0)
312
+
313
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
314
+ if bias:
315
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
316
+ else:
317
+ self.register_parameter('bias', None)
318
+ self.init_weights()
319
+
320
+ def init_weights(self):
321
+ n = self.in_channels
322
+ for k in self.kernel_size:
323
+ n *= k
324
+ stdv = 1. / math.sqrt(n)
325
+ self.weight.data.uniform_(-stdv, stdv)
326
+ if self.bias is not None:
327
+ self.bias.data.zero_()
328
+
329
+ def forward(self, x, offset, mask):
330
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
331
+ self.groups, self.deformable_groups)
332
+
333
+
334
+ class ModulatedDeformConvPack(ModulatedDeformConv):
335
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
336
+
337
+ Args:
338
+ in_channels (int): Same as nn.Conv2d.
339
+ out_channels (int): Same as nn.Conv2d.
340
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
341
+ stride (int or tuple[int]): Same as nn.Conv2d.
342
+ padding (int or tuple[int]): Same as nn.Conv2d.
343
+ dilation (int or tuple[int]): Same as nn.Conv2d.
344
+ groups (int): Same as nn.Conv2d.
345
+ bias (bool or str): If specified as `auto`, it will be decided by the
346
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
347
+ False.
348
+ """
349
+
350
+ _version = 2
351
+
352
+ def __init__(self, *args, **kwargs):
353
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
354
+
355
+ self.conv_offset = nn.Conv2d(
356
+ self.in_channels,
357
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
358
+ kernel_size=self.kernel_size,
359
+ stride=_pair(self.stride),
360
+ padding=_pair(self.padding),
361
+ dilation=_pair(self.dilation),
362
+ bias=True)
363
+ self.init_weights()
364
+
365
+ def init_weights(self):
366
+ super(ModulatedDeformConvPack, self).init_weights()
367
+ if hasattr(self, 'conv_offset'):
368
+ self.conv_offset.weight.data.zero_()
369
+ self.conv_offset.bias.data.zero_()
370
+
371
+ def forward(self, x):
372
+ out = self.conv_offset(x)
373
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
374
+ offset = torch.cat((o1, o2), dim=1)
375
+ mask = torch.sigmoid(mask)
376
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
377
+ self.groups, self.deformable_groups)
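ModulatedDeformConvPack above is meant as a drop-in replacement for nn.Conv2d that predicts its own offsets and masks; a hedged sketch, which needs a GPU and the compiled extension (or BASICSR_JIT=True for JIT compilation):

    import torch
    from basicsr.ops.dcn import ModulatedDeformConvPack  # module path as in this diff

    dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1, deformable_groups=8).cuda()
    x = torch.rand(2, 64, 32, 32).cuda()
    y = dcn(x)  # offsets and masks come from the internal conv_offset layer
    assert y.shape == x.shape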
CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp ADDED
@@ -0,0 +1,685 @@
1
+ // modify from
2
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
+
4
+ #include <torch/extension.h>
5
+ #include <ATen/DeviceGuard.h>
6
+
7
+ #include <cmath>
8
+ #include <vector>
9
+
10
+ void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
11
+ const int channels, const int height, const int width,
12
+ const int ksize_h, const int ksize_w, const int pad_h,
13
+ const int pad_w, const int stride_h, const int stride_w,
14
+ const int dilation_h, const int dilation_w,
15
+ const int parallel_imgs, const int deformable_group,
16
+ at::Tensor data_col);
17
+
18
+ void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
19
+ const int channels, const int height, const int width,
20
+ const int ksize_h, const int ksize_w, const int pad_h,
21
+ const int pad_w, const int stride_h, const int stride_w,
22
+ const int dilation_h, const int dilation_w,
23
+ const int parallel_imgs, const int deformable_group,
24
+ at::Tensor grad_im);
25
+
26
+ void deformable_col2im_coord(
27
+ const at::Tensor data_col, const at::Tensor data_im,
28
+ const at::Tensor data_offset, const int channels, const int height,
29
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
30
+ const int pad_w, const int stride_h, const int stride_w,
31
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
32
+ const int deformable_group, at::Tensor grad_offset);
33
+
34
+ void modulated_deformable_im2col_cuda(
35
+ const at::Tensor data_im, const at::Tensor data_offset,
36
+ const at::Tensor data_mask, const int batch_size, const int channels,
37
+ const int height_im, const int width_im, const int height_col,
38
+ const int width_col, const int kernel_h, const int kenerl_w,
39
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
40
+ const int dilation_h, const int dilation_w, const int deformable_group,
41
+ at::Tensor data_col);
42
+
43
+ void modulated_deformable_col2im_cuda(
44
+ const at::Tensor data_col, const at::Tensor data_offset,
45
+ const at::Tensor data_mask, const int batch_size, const int channels,
46
+ const int height_im, const int width_im, const int height_col,
47
+ const int width_col, const int kernel_h, const int kenerl_w,
48
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
49
+ const int dilation_h, const int dilation_w, const int deformable_group,
50
+ at::Tensor grad_im);
51
+
52
+ void modulated_deformable_col2im_coord_cuda(
53
+ const at::Tensor data_col, const at::Tensor data_im,
54
+ const at::Tensor data_offset, const at::Tensor data_mask,
55
+ const int batch_size, const int channels, const int height_im,
56
+ const int width_im, const int height_col, const int width_col,
57
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
58
+ const int stride_h, const int stride_w, const int dilation_h,
59
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
60
+ at::Tensor grad_mask);
61
+
62
+ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
63
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
64
+ int padW, int dilationH, int dilationW, int group,
65
+ int deformable_group) {
66
+ TORCH_CHECK(weight.ndimension() == 4,
67
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
68
+ "but got: %s",
69
+ weight.ndimension());
70
+
71
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
72
+
73
+ TORCH_CHECK(kW > 0 && kH > 0,
74
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
75
+ kW);
76
+
77
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
78
+ "kernel size should be consistent with weight, ",
79
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
80
+ kW, weight.size(2), weight.size(3));
81
+
82
+ TORCH_CHECK(dW > 0 && dH > 0,
83
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
84
+
85
+ TORCH_CHECK(
86
+ dilationW > 0 && dilationH > 0,
87
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
88
+ dilationH, dilationW);
89
+
90
+ int ndim = input.ndimension();
91
+ int dimf = 0;
92
+ int dimh = 1;
93
+ int dimw = 2;
94
+
95
+ if (ndim == 4) {
96
+ dimf++;
97
+ dimh++;
98
+ dimw++;
99
+ }
100
+
101
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
102
+ ndim);
103
+
104
+ long nInputPlane = weight.size(1) * group;
105
+ long inputHeight = input.size(dimh);
106
+ long inputWidth = input.size(dimw);
107
+ long nOutputPlane = weight.size(0);
108
+ long outputHeight =
109
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
110
+ long outputWidth =
111
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
112
+
113
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
114
+ "input channels must divide deformable group size");
115
+
116
+ if (outputWidth < 1 || outputHeight < 1)
117
+ AT_ERROR(
118
+ "Given input size: (%ld x %ld x %ld). "
119
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
120
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
121
+ outputWidth);
122
+
123
+ TORCH_CHECK(input.size(1) == nInputPlane,
124
+ "invalid number of input planes, expected: %d, but got: %d",
125
+ nInputPlane, input.size(1));
126
+
127
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
128
+ "input image is smaller than kernel");
129
+
130
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
131
+ "invalid spatial size of offset, expected height: %d width: %d, but "
132
+ "got height: %d width: %d",
133
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
134
+
135
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
136
+ "invalid number of channels of offset");
137
+
138
+ if (gradOutput != NULL) {
139
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
140
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
141
+ nOutputPlane, gradOutput->size(dimf));
142
+
143
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
144
+ gradOutput->size(dimw) == outputWidth),
145
+ "invalid size of gradOutput, expected height: %d width: %d , but "
146
+ "got height: %d width: %d",
147
+ outputHeight, outputWidth, gradOutput->size(dimh),
148
+ gradOutput->size(dimw));
149
+ }
150
+ }
151
+
152
+ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
153
+ at::Tensor offset, at::Tensor output,
154
+ at::Tensor columns, at::Tensor ones, int kW,
155
+ int kH, int dW, int dH, int padW, int padH,
156
+ int dilationW, int dilationH, int group,
157
+ int deformable_group, int im2col_step) {
158
+ // todo: resize columns to include im2col: done
159
+ // todo: add im2col_step as input
160
+ // todo: add new output buffer and transpose it to output (or directly
161
+ // transpose output) todo: possibly change data indexing because of
162
+ // parallel_imgs
163
+
164
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
165
+ dilationH, dilationW, group, deformable_group);
166
+ at::DeviceGuard guard(input.device());
167
+
168
+ input = input.contiguous();
169
+ offset = offset.contiguous();
170
+ weight = weight.contiguous();
171
+
172
+ int batch = 1;
173
+ if (input.ndimension() == 3) {
174
+ // Force batch
175
+ batch = 0;
176
+ input.unsqueeze_(0);
177
+ offset.unsqueeze_(0);
178
+ }
179
+
180
+ // todo: assert batchsize dividable by im2col_step
181
+
182
+ long batchSize = input.size(0);
183
+ long nInputPlane = input.size(1);
184
+ long inputHeight = input.size(2);
185
+ long inputWidth = input.size(3);
186
+
187
+ long nOutputPlane = weight.size(0);
188
+
189
+ long outputWidth =
190
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
191
+ long outputHeight =
192
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
193
+
194
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
195
+
196
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
197
+ outputHeight, outputWidth});
198
+ columns = at::zeros(
199
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
200
+ input.options());
201
+
202
+ if (ones.ndimension() != 2 ||
203
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
204
+ ones = at::ones({outputHeight, outputWidth}, input.options());
205
+ }
206
+
207
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
208
+ inputHeight, inputWidth});
209
+ offset =
210
+ offset.view({batchSize / im2col_step, im2col_step,
211
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
212
+
213
+ at::Tensor output_buffer =
214
+ at::zeros({batchSize / im2col_step, nOutputPlane,
215
+ im2col_step * outputHeight, outputWidth},
216
+ output.options());
217
+
218
+ output_buffer = output_buffer.view(
219
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
220
+ output_buffer.size(2), output_buffer.size(3)});
221
+
222
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
223
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
224
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
225
+ dilationW, im2col_step, deformable_group, columns);
226
+
227
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
228
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
229
+ weight.size(2), weight.size(3)});
230
+
231
+ for (int g = 0; g < group; g++) {
232
+ output_buffer[elt][g] = output_buffer[elt][g]
233
+ .flatten(1)
234
+ .addmm_(weight[g].flatten(1), columns[g])
235
+ .view_as(output_buffer[elt][g]);
236
+ }
237
+ }
238
+
239
+ output_buffer = output_buffer.view(
240
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
241
+ output_buffer.size(3), output_buffer.size(4)});
242
+
243
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
244
+ im2col_step, outputHeight, outputWidth});
245
+ output_buffer.transpose_(1, 2);
246
+ output.copy_(output_buffer);
247
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
248
+
249
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
250
+ offset = offset.view(
251
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
252
+
253
+ if (batch == 0) {
254
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
255
+ input = input.view({nInputPlane, inputHeight, inputWidth});
256
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
257
+ }
258
+
259
+ return 1;
260
+ }
261
+
262
+ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
263
+ at::Tensor gradOutput, at::Tensor gradInput,
264
+ at::Tensor gradOffset, at::Tensor weight,
265
+ at::Tensor columns, int kW, int kH, int dW,
266
+ int dH, int padW, int padH, int dilationW,
267
+ int dilationH, int group,
268
+ int deformable_group, int im2col_step) {
269
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
270
+ dilationH, dilationW, group, deformable_group);
271
+ at::DeviceGuard guard(input.device());
272
+
273
+ input = input.contiguous();
274
+ offset = offset.contiguous();
275
+ gradOutput = gradOutput.contiguous();
276
+ weight = weight.contiguous();
277
+
278
+ int batch = 1;
279
+
280
+ if (input.ndimension() == 3) {
281
+ // Force batch
282
+ batch = 0;
283
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
284
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
285
+ gradOutput = gradOutput.view(
286
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
287
+ }
288
+
289
+ long batchSize = input.size(0);
290
+ long nInputPlane = input.size(1);
291
+ long inputHeight = input.size(2);
292
+ long inputWidth = input.size(3);
293
+
294
+ long nOutputPlane = weight.size(0);
295
+
296
+ long outputWidth =
297
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
298
+ long outputHeight =
299
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
300
+
301
+ TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
302
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
303
+ columns = at::zeros(
304
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
305
+ input.options());
306
+
307
+ // change order of grad output
308
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
309
+ nOutputPlane, outputHeight, outputWidth});
310
+ gradOutput.transpose_(1, 2);
311
+
312
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
313
+ inputHeight, inputWidth});
314
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
315
+ inputHeight, inputWidth});
316
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
317
+ deformable_group * 2 * kH * kW, outputHeight,
318
+ outputWidth});
319
+ offset =
320
+ offset.view({batchSize / im2col_step, im2col_step,
321
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
322
+
323
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
324
+ // divide into groups
325
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
326
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
327
+ weight.size(2), weight.size(3)});
328
+ gradOutput = gradOutput.view(
329
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
330
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
331
+
332
+ for (int g = 0; g < group; g++) {
333
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
334
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
335
+ }
336
+
337
+ columns =
338
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
339
+ gradOutput = gradOutput.view(
340
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
341
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
342
+
343
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
344
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
345
+ dilationH, dilationW, im2col_step, deformable_group,
346
+ gradOffset[elt]);
347
+
348
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
349
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
350
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
351
+ }
352
+
353
+ gradOutput.transpose_(1, 2);
354
+ gradOutput =
355
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
356
+
357
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
358
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
359
+ gradOffset = gradOffset.view(
360
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
361
+ offset = offset.view(
362
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
363
+
364
+ if (batch == 0) {
365
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
366
+ input = input.view({nInputPlane, inputHeight, inputWidth});
367
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
368
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
369
+ gradOffset =
370
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
371
+ }
372
+
373
+ return 1;
374
+ }
375
+
376
+ int deform_conv_backward_parameters_cuda(
377
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
378
+ at::Tensor gradWeight, // at::Tensor gradBias,
379
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
380
+ int padW, int padH, int dilationW, int dilationH, int group,
381
+ int deformable_group, float scale, int im2col_step) {
382
+ // todo: transpose and reshape outGrad
383
+ // todo: reshape columns
384
+ // todo: add im2col_step as input
385
+
386
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
387
+ padW, dilationH, dilationW, group, deformable_group);
388
+ at::DeviceGuard guard(input.device());
389
+
390
+ input = input.contiguous();
391
+ offset = offset.contiguous();
392
+ gradOutput = gradOutput.contiguous();
393
+
394
+ int batch = 1;
395
+
396
+ if (input.ndimension() == 3) {
397
+ // Force batch
398
+ batch = 0;
399
+ input = input.view(
400
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
401
+ gradOutput = gradOutput.view(
402
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
403
+ }
404
+
405
+ long batchSize = input.size(0);
406
+ long nInputPlane = input.size(1);
407
+ long inputHeight = input.size(2);
408
+ long inputWidth = input.size(3);
409
+
410
+ long nOutputPlane = gradWeight.size(0);
411
+
412
+ long outputWidth =
413
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
414
+ long outputHeight =
415
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
416
+
417
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
418
+
419
+ columns = at::zeros(
420
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
421
+ input.options());
422
+
423
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
424
+ nOutputPlane, outputHeight, outputWidth});
425
+ gradOutput.transpose_(1, 2);
426
+
427
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
428
+ gradOutputBuffer =
429
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
430
+ outputHeight, outputWidth});
431
+ gradOutputBuffer.copy_(gradOutput);
432
+ gradOutputBuffer =
433
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
434
+ im2col_step * outputHeight, outputWidth});
435
+
436
+ gradOutput.transpose_(1, 2);
437
+ gradOutput =
438
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
439
+
440
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
441
+ inputHeight, inputWidth});
442
+ offset =
443
+ offset.view({batchSize / im2col_step, im2col_step,
444
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
445
+
446
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
447
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
448
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
449
+ dilationW, im2col_step, deformable_group, columns);
450
+
451
+ // divide into group
452
+ gradOutputBuffer = gradOutputBuffer.view(
453
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
454
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
455
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
456
+ gradWeight =
457
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
458
+ gradWeight.size(2), gradWeight.size(3)});
459
+
460
+ for (int g = 0; g < group; g++) {
461
+ gradWeight[g] = gradWeight[g]
462
+ .flatten(1)
463
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
464
+ columns[g].transpose(1, 0), 1.0, scale)
465
+ .view_as(gradWeight[g]);
466
+ }
467
+ gradOutputBuffer = gradOutputBuffer.view(
468
+ {gradOutputBuffer.size(0),
469
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
470
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
471
+ columns =
472
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
473
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
474
+ gradWeight.size(2), gradWeight.size(3),
475
+ gradWeight.size(4)});
476
+ }
477
+
478
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
479
+ offset = offset.view(
480
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
481
+
482
+ if (batch == 0) {
483
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
484
+ input = input.view({nInputPlane, inputHeight, inputWidth});
485
+ }
486
+
487
+ return 1;
488
+ }
489
+
490
+ void modulated_deform_conv_cuda_forward(
491
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
492
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
493
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
494
+ const int pad_h, const int pad_w, const int dilation_h,
495
+ const int dilation_w, const int group, const int deformable_group,
496
+ const bool with_bias) {
497
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
498
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
499
+ at::DeviceGuard guard(input.device());
500
+
501
+ const int batch = input.size(0);
502
+ const int channels = input.size(1);
503
+ const int height = input.size(2);
504
+ const int width = input.size(3);
505
+
506
+ const int channels_out = weight.size(0);
507
+ const int channels_kernel = weight.size(1);
508
+ const int kernel_h_ = weight.size(2);
509
+ const int kernel_w_ = weight.size(3);
510
+
511
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
512
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
513
+ kernel_h, kernel_w, kernel_h_, kernel_w_);
514
+ if (channels != channels_kernel * group)
515
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
516
+ channels, channels_kernel * group);
517
+
518
+ const int height_out =
519
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
520
+ const int width_out =
521
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
522
+
523
+ if (ones.ndimension() != 2 ||
524
+ ones.size(0) * ones.size(1) < height_out * width_out) {
525
+ // Resize plane and fill with ones...
526
+ ones = at::ones({height_out, width_out}, input.options());
527
+ }
528
+
529
+ // resize output
530
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
531
+ // resize temporary columns
532
+ columns =
533
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
534
+ input.options());
535
+
536
+ output = output.view({output.size(0), group, output.size(1) / group,
537
+ output.size(2), output.size(3)});
538
+
539
+ for (int b = 0; b < batch; b++) {
540
+ modulated_deformable_im2col_cuda(
541
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
542
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
543
+ dilation_h, dilation_w, deformable_group, columns);
544
+
545
+ // divide into group
546
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
547
+ weight.size(2), weight.size(3)});
548
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
549
+
550
+ for (int g = 0; g < group; g++) {
551
+ output[b][g] = output[b][g]
552
+ .flatten(1)
553
+ .addmm_(weight[g].flatten(1), columns[g])
554
+ .view_as(output[b][g]);
555
+ }
556
+
557
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
558
+ weight.size(3), weight.size(4)});
559
+ columns =
560
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
561
+ }
562
+
563
+ output = output.view({output.size(0), output.size(1) * output.size(2),
564
+ output.size(3), output.size(4)});
565
+
566
+ if (with_bias) {
567
+ output += bias.view({1, bias.size(0), 1, 1});
568
+ }
569
+ }
570
+
571
+ void modulated_deform_conv_cuda_backward(
572
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
573
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
574
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
575
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
576
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
577
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
578
+ const bool with_bias) {
579
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
580
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
581
+ at::DeviceGuard guard(input.device());
582
+
583
+ const int batch = input.size(0);
584
+ const int channels = input.size(1);
585
+ const int height = input.size(2);
586
+ const int width = input.size(3);
587
+
588
+ const int channels_kernel = weight.size(1);
589
+ const int kernel_h_ = weight.size(2);
590
+ const int kernel_w_ = weight.size(3);
591
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
592
+ AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
593
+ kernel_h, kernel_w, kernel_h_, kernel_w_);
594
+ if (channels != channels_kernel * group)
595
+ AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
596
+ channels, channels_kernel * group);
597
+
598
+ const int height_out =
599
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
600
+ const int width_out =
601
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
602
+
603
+ if (ones.ndimension() != 2 ||
604
+ ones.size(0) * ones.size(1) < height_out * width_out) {
605
+ // Resize plane and fill with ones...
606
+ ones = at::ones({height_out, width_out}, input.options());
607
+ }
608
+
609
+ grad_input = grad_input.view({batch, channels, height, width});
610
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
611
+ input.options());
612
+
613
+ grad_output =
614
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
615
+ grad_output.size(2), grad_output.size(3)});
616
+
617
+ for (int b = 0; b < batch; b++) {
618
+ // divide into group
619
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
620
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
621
+ weight.size(2), weight.size(3)});
622
+
623
+ for (int g = 0; g < group; g++) {
624
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
625
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
626
+ }
627
+
628
+ columns =
629
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
630
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
631
+ weight.size(3), weight.size(4)});
632
+
633
+ // gradient w.r.t. input coordinate data
634
+ modulated_deformable_col2im_coord_cuda(
635
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
636
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
637
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
638
+ grad_mask[b]);
639
+ // gradient w.r.t. input data
640
+ modulated_deformable_col2im_cuda(
641
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
642
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
643
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
644
+
645
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
646
+ // group
647
+ modulated_deformable_im2col_cuda(
648
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
649
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
650
+ dilation_h, dilation_w, deformable_group, columns);
651
+
652
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
653
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
654
+ grad_weight.size(1), grad_weight.size(2),
655
+ grad_weight.size(3)});
656
+ if (with_bias)
657
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
658
+
659
+ for (int g = 0; g < group; g++) {
660
+ grad_weight[g] =
661
+ grad_weight[g]
662
+ .flatten(1)
663
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
664
+ .view_as(grad_weight[g]);
665
+ if (with_bias) {
666
+ grad_bias[g] =
667
+ grad_bias[g]
668
+ .view({-1, 1})
669
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
670
+ .view(-1);
671
+ }
672
+ }
673
+
674
+ columns =
675
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
676
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
677
+ grad_weight.size(2), grad_weight.size(3),
678
+ grad_weight.size(4)});
679
+ if (with_bias)
680
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
681
+ }
682
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
683
+ grad_output.size(2), grad_output.size(3),
684
+ grad_output.size(4)});
685
+ }
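Both paths above compute the output spatial size with the standard convolution arithmetic, height_out = (height + 2*pad_h - (dilation_h*(kernel_h-1)+1)) / stride_h + 1, and likewise for the width. A minimal Python sketch of that arithmetic, handy for sanity-checking the shapes these functions expect (the example sizes below are arbitrary):

def conv_output_size(size, kernel, stride=1, pad=0, dilation=1):
    # Mirrors the height_out / width_out computation in
    # modulated_deform_conv_cuda_forward / _backward above.
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * pad - effective_kernel) // stride + 1

# A 64x64 map with a 3x3 kernel, stride 1, padding 1 keeps its size ...
assert conv_output_size(64, kernel=3, stride=1, pad=1) == 64
# ... while stride 2 halves it.
assert conv_output_size(64, kernel=3, stride=2, pad=1) == 32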
CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu ADDED
@@ -0,0 +1,867 @@
1
+ /*!
2
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3
+ *
4
+ * COPYRIGHT
5
+ *
6
+ * All contributions by the University of California:
7
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8
+ * All rights reserved.
9
+ *
10
+ * All other contributions:
11
+ * Copyright (c) 2014-2017, the respective contributors
12
+ * All rights reserved.
13
+ *
14
+ * Caffe uses a shared copyright model: each contributor holds copyright over
15
+ * their contributions to Caffe. The project versioning records all such
16
+ * contribution and copyright details. If a contributor wants to further mark
17
+ * their specific copyright on a particular contribution, they should indicate
18
+ * their copyright solely in the commit message of the change when it is
19
+ * committed.
20
+ *
21
+ * LICENSE
22
+ *
23
+ * Redistribution and use in source and binary forms, with or without
24
+ * modification, are permitted provided that the following conditions are met:
25
+ *
26
+ * 1. Redistributions of source code must retain the above copyright notice, this
27
+ * list of conditions and the following disclaimer.
28
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
29
+ * this list of conditions and the following disclaimer in the documentation
30
+ * and/or other materials provided with the distribution.
31
+ *
32
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42
+ *
43
+ * CONTRIBUTION AGREEMENT
44
+ *
45
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
46
+ * or otherwise, the contributor releases their content to the
47
+ * license and copyright terms herein.
48
+ *
49
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
50
+ *
51
+ * Copyright (c) 2018 Microsoft
52
+ * Licensed under The MIT License [see LICENSE for details]
53
+ * \file modulated_deformable_im2col.cuh
54
+ * \brief Function definitions of converting an image to
55
+ * column matrix based on kernel, padding, dilation, and offset.
56
+ * These functions are mainly used in deformable convolution operators.
57
+ * \ref: https://arxiv.org/abs/1703.06211
58
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
59
+ */
60
+
61
+ // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
62
+
63
+ #include <ATen/ATen.h>
64
+ #include <ATen/cuda/CUDAContext.h>
65
+ #include <THC/THCAtomics.cuh>
66
+ #include <stdio.h>
67
+ #include <math.h>
68
+ #include <float.h>
69
+
70
+ using namespace at;
71
+
72
+ #define CUDA_KERNEL_LOOP(i, n) \
73
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
74
+ i += blockDim.x * gridDim.x)
75
+
76
+ const int CUDA_NUM_THREADS = 1024;
77
+ const int kMaxGridNum = 65535;
78
+
79
+ inline int GET_BLOCKS(const int N)
80
+ {
81
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
82
+ }
83
+
84
+ template <typename scalar_t>
85
+ __device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
86
+ const int height, const int width, scalar_t h, scalar_t w)
87
+ {
88
+
89
+ int h_low = floor(h);
90
+ int w_low = floor(w);
91
+ int h_high = h_low + 1;
92
+ int w_high = w_low + 1;
93
+
94
+ scalar_t lh = h - h_low;
95
+ scalar_t lw = w - w_low;
96
+ scalar_t hh = 1 - lh, hw = 1 - lw;
97
+
98
+ scalar_t v1 = 0;
99
+ if (h_low >= 0 && w_low >= 0)
100
+ v1 = bottom_data[h_low * data_width + w_low];
101
+ scalar_t v2 = 0;
102
+ if (h_low >= 0 && w_high <= width - 1)
103
+ v2 = bottom_data[h_low * data_width + w_high];
104
+ scalar_t v3 = 0;
105
+ if (h_high <= height - 1 && w_low >= 0)
106
+ v3 = bottom_data[h_high * data_width + w_low];
107
+ scalar_t v4 = 0;
108
+ if (h_high <= height - 1 && w_high <= width - 1)
109
+ v4 = bottom_data[h_high * data_width + w_high];
110
+
111
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
112
+
113
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
114
+ return val;
115
+ }
116
+
117
+ template <typename scalar_t>
118
+ __device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
119
+ const int h, const int w, const int height, const int width)
120
+ {
121
+
122
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
123
+ {
124
+ //empty
125
+ return 0;
126
+ }
127
+
128
+ int argmax_h_low = floor(argmax_h);
129
+ int argmax_w_low = floor(argmax_w);
130
+ int argmax_h_high = argmax_h_low + 1;
131
+ int argmax_w_high = argmax_w_low + 1;
132
+
133
+ scalar_t weight = 0;
134
+ if (h == argmax_h_low && w == argmax_w_low)
135
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
136
+ if (h == argmax_h_low && w == argmax_w_high)
137
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
138
+ if (h == argmax_h_high && w == argmax_w_low)
139
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
140
+ if (h == argmax_h_high && w == argmax_w_high)
141
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
142
+ return weight;
143
+ }
144
+
145
+ template <typename scalar_t>
146
+ __device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
147
+ const int height, const int width, const scalar_t *im_data,
148
+ const int data_width, const int bp_dir)
149
+ {
150
+
151
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
152
+ {
153
+ //empty
154
+ return 0;
155
+ }
156
+
157
+ int argmax_h_low = floor(argmax_h);
158
+ int argmax_w_low = floor(argmax_w);
159
+ int argmax_h_high = argmax_h_low + 1;
160
+ int argmax_w_high = argmax_w_low + 1;
161
+
162
+ scalar_t weight = 0;
163
+
164
+ if (bp_dir == 0)
165
+ {
166
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
167
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
168
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
169
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
170
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
171
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
172
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
173
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
174
+ }
175
+ else if (bp_dir == 1)
176
+ {
177
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
178
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
179
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
180
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
181
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
182
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
183
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
184
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
185
+ }
186
+
187
+ return weight;
188
+ }
189
+
190
+ template <typename scalar_t>
191
+ __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
192
+ const int height, const int width, const int kernel_h, const int kernel_w,
193
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
194
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
195
+ const int batch_size, const int num_channels, const int deformable_group,
196
+ const int height_col, const int width_col,
197
+ scalar_t *data_col)
198
+ {
199
+ CUDA_KERNEL_LOOP(index, n)
200
+ {
201
+ // index index of output matrix
202
+ const int w_col = index % width_col;
203
+ const int h_col = (index / width_col) % height_col;
204
+ const int b_col = (index / width_col / height_col) % batch_size;
205
+ const int c_im = (index / width_col / height_col) / batch_size;
206
+ const int c_col = c_im * kernel_h * kernel_w;
207
+
208
+ // compute deformable group index
209
+ const int deformable_group_index = c_im / channel_per_deformable_group;
210
+
211
+ const int h_in = h_col * stride_h - pad_h;
212
+ const int w_in = w_col * stride_w - pad_w;
213
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
214
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
215
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
216
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
217
+
218
+ for (int i = 0; i < kernel_h; ++i)
219
+ {
220
+ for (int j = 0; j < kernel_w; ++j)
221
+ {
222
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
223
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
224
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
225
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
226
+ scalar_t val = static_cast<scalar_t>(0);
227
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
228
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
229
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
230
+ {
231
+ //const scalar_t map_h = i * dilation_h + offset_h;
232
+ //const scalar_t map_w = j * dilation_w + offset_w;
233
+ //const int cur_height = height - h_in;
234
+ //const int cur_width = width - w_in;
235
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
236
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
237
+ }
238
+ *data_col_ptr = val;
239
+ data_col_ptr += batch_size * height_col * width_col;
240
+ }
241
+ }
242
+ }
243
+ }
244
+
245
+ void deformable_im2col(
246
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
247
+ const int height, const int width, const int ksize_h, const int ksize_w,
248
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
249
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
250
+ const int deformable_group, at::Tensor data_col)
251
+ {
252
+ // num_axes should be smaller than block size
253
+ // todo: check parallel_imgs is correctly passed in
254
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
255
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
256
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
257
+ int channel_per_deformable_group = channels / deformable_group;
258
+
259
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
260
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
261
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
262
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
263
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
264
+
265
+ deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
266
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
267
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
268
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
269
+ height_col, width_col, data_col_);
270
+ }));
271
+
272
+ cudaError_t err = cudaGetLastError();
273
+ if (err != cudaSuccess)
274
+ {
275
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
276
+ }
277
+ }
278
+
279
+ template <typename scalar_t>
280
+ __global__ void deformable_col2im_gpu_kernel(
281
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
282
+ const int channels, const int height, const int width,
283
+ const int kernel_h, const int kernel_w,
284
+ const int pad_h, const int pad_w,
285
+ const int stride_h, const int stride_w,
286
+ const int dilation_h, const int dilation_w,
287
+ const int channel_per_deformable_group,
288
+ const int batch_size, const int deformable_group,
289
+ const int height_col, const int width_col,
290
+ scalar_t *grad_im)
291
+ {
292
+ CUDA_KERNEL_LOOP(index, n)
293
+ {
294
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
295
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
296
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
297
+ // compute the start and end of the output
298
+
299
+ const int deformable_group_index = c / channel_per_deformable_group;
300
+
301
+ int w_out = index % width_col;
302
+ int h_out = (index / width_col) % height_col;
303
+ int b = (index / width_col / height_col) % batch_size;
304
+ int w_in = w_out * stride_w - pad_w;
305
+ int h_in = h_out * stride_h - pad_h;
306
+
307
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
308
+ 2 * kernel_h * kernel_w * height_col * width_col;
309
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
310
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
311
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
312
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
313
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
314
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
315
+
316
+ const scalar_t cur_top_grad = data_col[index];
317
+ const int cur_h = (int)cur_inv_h_data;
318
+ const int cur_w = (int)cur_inv_w_data;
319
+ for (int dy = -2; dy <= 2; dy++)
320
+ {
321
+ for (int dx = -2; dx <= 2; dx++)
322
+ {
323
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
324
+ cur_w + dx >= 0 && cur_w + dx < width &&
325
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
326
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
327
+ {
328
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
329
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
330
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
331
+ }
332
+ }
333
+ }
334
+ }
335
+ }
336
+
337
+ void deformable_col2im(
338
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
339
+ const int height, const int width, const int ksize_h,
340
+ const int ksize_w, const int pad_h, const int pad_w,
341
+ const int stride_h, const int stride_w,
342
+ const int dilation_h, const int dilation_w,
343
+ const int parallel_imgs, const int deformable_group,
344
+ at::Tensor grad_im)
345
+ {
346
+
347
+ // todo: make sure parallel_imgs is passed in correctly
348
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
349
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
350
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
351
+ int channel_per_deformable_group = channels / deformable_group;
352
+
353
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
354
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
355
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
356
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
357
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
358
+
359
+ deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
360
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
361
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
362
+ dilation_h, dilation_w, channel_per_deformable_group,
363
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
364
+ }));
365
+
366
+ cudaError_t err = cudaGetLastError();
367
+ if (err != cudaSuccess)
368
+ {
369
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
370
+ }
371
+ }
372
+
373
+ template <typename scalar_t>
374
+ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
375
+ const scalar_t *data_im, const scalar_t *data_offset,
376
+ const int channels, const int height, const int width,
377
+ const int kernel_h, const int kernel_w,
378
+ const int pad_h, const int pad_w,
379
+ const int stride_h, const int stride_w,
380
+ const int dilation_h, const int dilation_w,
381
+ const int channel_per_deformable_group,
382
+ const int batch_size, const int offset_channels, const int deformable_group,
383
+ const int height_col, const int width_col, scalar_t *grad_offset)
384
+ {
385
+ CUDA_KERNEL_LOOP(index, n)
386
+ {
387
+ scalar_t val = 0;
388
+ int w = index % width_col;
389
+ int h = (index / width_col) % height_col;
390
+ int c = (index / width_col / height_col) % offset_channels;
391
+ int b = (index / width_col / height_col) / offset_channels;
392
+ // compute the start and end of the output
393
+
394
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
395
+ const int col_step = kernel_h * kernel_w;
396
+ int cnt = 0;
397
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
398
+ batch_size * width_col * height_col;
399
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
400
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
401
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
402
+ kernel_h * kernel_w * height_col * width_col;
403
+
404
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
405
+
406
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
407
+ {
408
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
409
+ const int bp_dir = offset_c % 2;
410
+
411
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
412
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
413
+ int w_out = col_pos % width_col;
414
+ int h_out = (col_pos / width_col) % height_col;
415
+ int w_in = w_out * stride_w - pad_w;
416
+ int h_in = h_out * stride_h - pad_h;
417
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
418
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
419
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
420
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
421
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
422
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
423
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
424
+ {
425
+ inv_h = inv_w = -2;
426
+ }
427
+ const scalar_t weight = get_coordinate_weight(
428
+ inv_h, inv_w,
429
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
430
+ val += weight * data_col_ptr[col_pos];
431
+ cnt += 1;
432
+ }
433
+
434
+ grad_offset[index] = val;
435
+ }
436
+ }
437
+
438
+ void deformable_col2im_coord(
439
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
440
+ const int channels, const int height, const int width, const int ksize_h,
441
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
442
+ const int stride_w, const int dilation_h, const int dilation_w,
443
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
444
+ {
445
+
446
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
447
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
448
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
449
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
450
+
451
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
452
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
453
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
454
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
455
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
456
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
457
+
458
+ deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
459
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
460
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
461
+ dilation_h, dilation_w, channel_per_deformable_group,
462
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
463
+ height_col, width_col, grad_offset_);
464
+ }));
465
+ }
466
+
467
+ template <typename scalar_t>
468
+ __device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
469
+ const int height, const int width, scalar_t h, scalar_t w)
470
+ {
471
+ int h_low = floor(h);
472
+ int w_low = floor(w);
473
+ int h_high = h_low + 1;
474
+ int w_high = w_low + 1;
475
+
476
+ scalar_t lh = h - h_low;
477
+ scalar_t lw = w - w_low;
478
+ scalar_t hh = 1 - lh, hw = 1 - lw;
479
+
480
+ scalar_t v1 = 0;
481
+ if (h_low >= 0 && w_low >= 0)
482
+ v1 = bottom_data[h_low * data_width + w_low];
483
+ scalar_t v2 = 0;
484
+ if (h_low >= 0 && w_high <= width - 1)
485
+ v2 = bottom_data[h_low * data_width + w_high];
486
+ scalar_t v3 = 0;
487
+ if (h_high <= height - 1 && w_low >= 0)
488
+ v3 = bottom_data[h_high * data_width + w_low];
489
+ scalar_t v4 = 0;
490
+ if (h_high <= height - 1 && w_high <= width - 1)
491
+ v4 = bottom_data[h_high * data_width + w_high];
492
+
493
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
494
+
495
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
496
+ return val;
497
+ }
498
+
499
+ template <typename scalar_t>
500
+ __device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
501
+ const int h, const int w, const int height, const int width)
502
+ {
503
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
504
+ {
505
+ //empty
506
+ return 0;
507
+ }
508
+
509
+ int argmax_h_low = floor(argmax_h);
510
+ int argmax_w_low = floor(argmax_w);
511
+ int argmax_h_high = argmax_h_low + 1;
512
+ int argmax_w_high = argmax_w_low + 1;
513
+
514
+ scalar_t weight = 0;
515
+ if (h == argmax_h_low && w == argmax_w_low)
516
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
517
+ if (h == argmax_h_low && w == argmax_w_high)
518
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
519
+ if (h == argmax_h_high && w == argmax_w_low)
520
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
521
+ if (h == argmax_h_high && w == argmax_w_high)
522
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
523
+ return weight;
524
+ }
525
+
526
+ template <typename scalar_t>
527
+ __device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
528
+ const int height, const int width, const scalar_t *im_data,
529
+ const int data_width, const int bp_dir)
530
+ {
531
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
532
+ {
533
+ //empty
534
+ return 0;
535
+ }
536
+
537
+ int argmax_h_low = floor(argmax_h);
538
+ int argmax_w_low = floor(argmax_w);
539
+ int argmax_h_high = argmax_h_low + 1;
540
+ int argmax_w_high = argmax_w_low + 1;
541
+
542
+ scalar_t weight = 0;
543
+
544
+ if (bp_dir == 0)
545
+ {
546
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
547
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
548
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
549
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
550
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
551
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
552
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
553
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
554
+ }
555
+ else if (bp_dir == 1)
556
+ {
557
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
558
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
559
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
560
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
561
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
562
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
563
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
564
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
565
+ }
566
+
567
+ return weight;
568
+ }
569
+
570
+ template <typename scalar_t>
571
+ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
572
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
573
+ const int height, const int width, const int kernel_h, const int kernel_w,
574
+ const int pad_h, const int pad_w,
575
+ const int stride_h, const int stride_w,
576
+ const int dilation_h, const int dilation_w,
577
+ const int channel_per_deformable_group,
578
+ const int batch_size, const int num_channels, const int deformable_group,
579
+ const int height_col, const int width_col,
580
+ scalar_t *data_col)
581
+ {
582
+ CUDA_KERNEL_LOOP(index, n)
583
+ {
584
+ // index index of output matrix
585
+ const int w_col = index % width_col;
586
+ const int h_col = (index / width_col) % height_col;
587
+ const int b_col = (index / width_col / height_col) % batch_size;
588
+ const int c_im = (index / width_col / height_col) / batch_size;
589
+ const int c_col = c_im * kernel_h * kernel_w;
590
+
591
+ // compute deformable group index
592
+ const int deformable_group_index = c_im / channel_per_deformable_group;
593
+
594
+ const int h_in = h_col * stride_h - pad_h;
595
+ const int w_in = w_col * stride_w - pad_w;
596
+
597
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
598
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
599
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
600
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
601
+
602
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
603
+
604
+ for (int i = 0; i < kernel_h; ++i)
605
+ {
606
+ for (int j = 0; j < kernel_w; ++j)
607
+ {
608
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
609
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
610
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
611
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
612
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
613
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
614
+ scalar_t val = static_cast<scalar_t>(0);
615
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
616
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
617
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
618
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
619
+ {
620
+ //const float map_h = i * dilation_h + offset_h;
621
+ //const float map_w = j * dilation_w + offset_w;
622
+ //const int cur_height = height - h_in;
623
+ //const int cur_width = width - w_in;
624
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
625
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
626
+ }
627
+ *data_col_ptr = val * mask;
628
+ data_col_ptr += batch_size * height_col * width_col;
629
+ //data_col_ptr += height_col * width_col;
630
+ }
631
+ }
632
+ }
633
+ }
634
+
635
+ template <typename scalar_t>
636
+ __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
637
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
638
+ const int channels, const int height, const int width,
639
+ const int kernel_h, const int kernel_w,
640
+ const int pad_h, const int pad_w,
641
+ const int stride_h, const int stride_w,
642
+ const int dilation_h, const int dilation_w,
643
+ const int channel_per_deformable_group,
644
+ const int batch_size, const int deformable_group,
645
+ const int height_col, const int width_col,
646
+ scalar_t *grad_im)
647
+ {
648
+ CUDA_KERNEL_LOOP(index, n)
649
+ {
650
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
651
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
652
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
653
+ // compute the start and end of the output
654
+
655
+ const int deformable_group_index = c / channel_per_deformable_group;
656
+
657
+ int w_out = index % width_col;
658
+ int h_out = (index / width_col) % height_col;
659
+ int b = (index / width_col / height_col) % batch_size;
660
+ int w_in = w_out * stride_w - pad_w;
661
+ int h_in = h_out * stride_h - pad_h;
662
+
663
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
664
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
665
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
666
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
667
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
668
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
669
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
670
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
671
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
672
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
673
+
674
+ const scalar_t cur_top_grad = data_col[index] * mask;
675
+ const int cur_h = (int)cur_inv_h_data;
676
+ const int cur_w = (int)cur_inv_w_data;
677
+ for (int dy = -2; dy <= 2; dy++)
678
+ {
679
+ for (int dx = -2; dx <= 2; dx++)
680
+ {
681
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
682
+ cur_w + dx >= 0 && cur_w + dx < width &&
683
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
684
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
685
+ {
686
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
687
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
688
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
689
+ }
690
+ }
691
+ }
692
+ }
693
+ }
694
+
695
+ template <typename scalar_t>
696
+ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
697
+ const scalar_t *data_col, const scalar_t *data_im,
698
+ const scalar_t *data_offset, const scalar_t *data_mask,
699
+ const int channels, const int height, const int width,
700
+ const int kernel_h, const int kernel_w,
701
+ const int pad_h, const int pad_w,
702
+ const int stride_h, const int stride_w,
703
+ const int dilation_h, const int dilation_w,
704
+ const int channel_per_deformable_group,
705
+ const int batch_size, const int offset_channels, const int deformable_group,
706
+ const int height_col, const int width_col,
707
+ scalar_t *grad_offset, scalar_t *grad_mask)
708
+ {
709
+ CUDA_KERNEL_LOOP(index, n)
710
+ {
711
+ scalar_t val = 0, mval = 0;
712
+ int w = index % width_col;
713
+ int h = (index / width_col) % height_col;
714
+ int c = (index / width_col / height_col) % offset_channels;
715
+ int b = (index / width_col / height_col) / offset_channels;
716
+ // compute the start and end of the output
717
+
718
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
719
+ const int col_step = kernel_h * kernel_w;
720
+ int cnt = 0;
721
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
722
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
723
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
724
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
725
+
726
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
727
+
728
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
729
+ {
730
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
731
+ const int bp_dir = offset_c % 2;
732
+
733
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
734
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
735
+ int w_out = col_pos % width_col;
736
+ int h_out = (col_pos / width_col) % height_col;
737
+ int w_in = w_out * stride_w - pad_w;
738
+ int h_in = h_out * stride_h - pad_h;
739
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
740
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
741
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
742
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
743
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
744
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
745
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
746
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
747
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
748
+ {
749
+ inv_h = inv_w = -2;
750
+ }
751
+ else
752
+ {
753
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
754
+ }
755
+ const scalar_t weight = dmcn_get_coordinate_weight(
756
+ inv_h, inv_w,
757
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
758
+ val += weight * data_col_ptr[col_pos] * mask;
759
+ cnt += 1;
760
+ }
761
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
762
+ grad_offset[index] = val;
763
+ if (offset_c % 2 == 0)
764
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
765
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
766
+ }
767
+ }
768
+
769
+ void modulated_deformable_im2col_cuda(
770
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
771
+ const int batch_size, const int channels, const int height_im, const int width_im,
772
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
773
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
774
+ const int dilation_h, const int dilation_w,
775
+ const int deformable_group, at::Tensor data_col)
776
+ {
777
+ // num_axes should be smaller than block size
778
+ const int channel_per_deformable_group = channels / deformable_group;
779
+ const int num_kernels = channels * batch_size * height_col * width_col;
780
+
781
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
782
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
783
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
784
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
785
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
786
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
787
+
788
+ modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
789
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
790
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
791
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
792
+ }));
793
+
794
+ cudaError_t err = cudaGetLastError();
795
+ if (err != cudaSuccess)
796
+ {
797
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
798
+ }
799
+ }
800
+
801
+ void modulated_deformable_col2im_cuda(
802
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
803
+ const int batch_size, const int channels, const int height_im, const int width_im,
804
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
805
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
806
+ const int dilation_h, const int dilation_w,
807
+ const int deformable_group, at::Tensor grad_im)
808
+ {
809
+
810
+ const int channel_per_deformable_group = channels / deformable_group;
811
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
812
+
813
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
814
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
815
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
816
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
817
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
818
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
819
+
820
+ modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
821
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
822
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
823
+ dilation_h, dilation_w, channel_per_deformable_group,
824
+ batch_size, deformable_group, height_col, width_col, grad_im_);
825
+ }));
826
+
827
+ cudaError_t err = cudaGetLastError();
828
+ if (err != cudaSuccess)
829
+ {
830
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
831
+ }
832
+ }
833
+
834
+ void modulated_deformable_col2im_coord_cuda(
835
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
836
+ const int batch_size, const int channels, const int height_im, const int width_im,
837
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
838
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
839
+ const int dilation_h, const int dilation_w,
840
+ const int deformable_group,
841
+ at::Tensor grad_offset, at::Tensor grad_mask)
842
+ {
843
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
844
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
845
+
846
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
847
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
848
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
849
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
850
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
851
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
852
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
853
+ scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
854
+
855
+ modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
856
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
857
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
858
+ dilation_h, dilation_w, channel_per_deformable_group,
859
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
860
+ grad_offset_, grad_mask_);
861
+ }));
862
+ cudaError_t err = cudaGetLastError();
863
+ if (err != cudaSuccess)
864
+ {
865
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
866
+ }
867
+ }
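The sampling step shared by all of these kernels is plain bilinear interpolation at a fractional (h, w) location, with corners that fall outside the feature map contributing zero (deformable_im2col_bilinear and dmcn_im2col_bilinear above). A minimal NumPy sketch of the same computation, for reference only:

import math
import numpy as np

def im2col_bilinear(im, h, w):
    # im: 2-D array (height x width); (h, w): fractional sampling location.
    # Out-of-bounds corners contribute zero, matching the CUDA helpers.
    height, width = im.shape
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw

    def at(y, x):
        return float(im[y, x]) if 0 <= y < height and 0 <= x < width else 0.0

    return (hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high)
            + lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high))

im = np.arange(16, dtype=np.float32).reshape(4, 4)
print(im2col_bilinear(im, 1.5, 2.5))  # 8.5, the average of the four surrounding pixels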
CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp ADDED
@@ -0,0 +1,164 @@
1
+ // modify from
2
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
+
4
+ #include <torch/extension.h>
5
+ #include <ATen/DeviceGuard.h>
6
+
7
+ #include <cmath>
8
+ #include <vector>
9
+
10
+ #define WITH_CUDA // always use cuda
11
+ #ifdef WITH_CUDA
12
+ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
13
+ at::Tensor offset, at::Tensor output,
14
+ at::Tensor columns, at::Tensor ones, int kW,
15
+ int kH, int dW, int dH, int padW, int padH,
16
+ int dilationW, int dilationH, int group,
17
+ int deformable_group, int im2col_step);
18
+
19
+ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
20
+ at::Tensor gradOutput, at::Tensor gradInput,
21
+ at::Tensor gradOffset, at::Tensor weight,
22
+ at::Tensor columns, int kW, int kH, int dW,
23
+ int dH, int padW, int padH, int dilationW,
24
+ int dilationH, int group,
25
+ int deformable_group, int im2col_step);
26
+
27
+ int deform_conv_backward_parameters_cuda(
28
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
29
+ at::Tensor gradWeight, // at::Tensor gradBias,
30
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
31
+ int padW, int padH, int dilationW, int dilationH, int group,
32
+ int deformable_group, float scale, int im2col_step);
33
+
34
+ void modulated_deform_conv_cuda_forward(
35
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
36
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
37
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
38
+ const int pad_h, const int pad_w, const int dilation_h,
39
+ const int dilation_w, const int group, const int deformable_group,
40
+ const bool with_bias);
41
+
42
+ void modulated_deform_conv_cuda_backward(
43
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
44
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
45
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
46
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
47
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
48
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
49
+ const bool with_bias);
50
+ #endif
51
+
52
+ int deform_conv_forward(at::Tensor input, at::Tensor weight,
53
+ at::Tensor offset, at::Tensor output,
54
+ at::Tensor columns, at::Tensor ones, int kW,
55
+ int kH, int dW, int dH, int padW, int padH,
56
+ int dilationW, int dilationH, int group,
57
+ int deformable_group, int im2col_step) {
58
+ if (input.device().is_cuda()) {
59
+ #ifdef WITH_CUDA
60
+ return deform_conv_forward_cuda(input, weight, offset, output, columns,
61
+ ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
62
+ deformable_group, im2col_step);
63
+ #else
64
+ AT_ERROR("deform conv is not compiled with GPU support");
65
+ #endif
66
+ }
67
+ AT_ERROR("deform conv is not implemented on CPU");
68
+ }
69
+
70
+ int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
71
+ at::Tensor gradOutput, at::Tensor gradInput,
72
+ at::Tensor gradOffset, at::Tensor weight,
73
+ at::Tensor columns, int kW, int kH, int dW,
74
+ int dH, int padW, int padH, int dilationW,
75
+ int dilationH, int group,
76
+ int deformable_group, int im2col_step) {
77
+ if (input.device().is_cuda()) {
78
+ #ifdef WITH_CUDA
79
+ return deform_conv_backward_input_cuda(input, offset, gradOutput,
80
+ gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
81
+ dilationW, dilationH, group, deformable_group, im2col_step);
82
+ #else
83
+ AT_ERROR("deform conv is not compiled with GPU support");
84
+ #endif
85
+ }
86
+ AT_ERROR("deform conv is not implemented on CPU");
87
+ }
88
+
89
+ int deform_conv_backward_parameters(
90
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
91
+ at::Tensor gradWeight, // at::Tensor gradBias,
92
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
93
+ int padW, int padH, int dilationW, int dilationH, int group,
94
+ int deformable_group, float scale, int im2col_step) {
95
+ if (input.device().is_cuda()) {
96
+ #ifdef WITH_CUDA
97
+ return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
98
+ gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
99
+ dilationH, group, deformable_group, scale, im2col_step);
100
+ #else
101
+ AT_ERROR("deform conv is not compiled with GPU support");
102
+ #endif
103
+ }
104
+ AT_ERROR("deform conv is not implemented on CPU");
105
+ }
106
+
107
+ void modulated_deform_conv_forward(
108
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
109
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
110
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
111
+ const int pad_h, const int pad_w, const int dilation_h,
112
+ const int dilation_w, const int group, const int deformable_group,
113
+ const bool with_bias) {
114
+ if (input.device().is_cuda()) {
115
+ #ifdef WITH_CUDA
116
+ return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
117
+ offset, mask, output, columns, kernel_h, kernel_w, stride_h,
118
+ stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
119
+ deformable_group, with_bias);
120
+ #else
121
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
122
+ #endif
123
+ }
124
+ AT_ERROR("modulated deform conv is not implemented on CPU");
125
+ }
126
+
127
+ void modulated_deform_conv_backward(
128
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
129
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
130
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
131
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
132
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
133
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
134
+ const bool with_bias) {
135
+ if (input.device().is_cuda()) {
136
+ #ifdef WITH_CUDA
137
+ return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
138
+ offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
139
+ grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
140
+ pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
141
+ with_bias);
142
+ #else
143
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
144
+ #endif
145
+ }
146
+ AT_ERROR("modulated deform conv is not implemented on CPU");
147
+ }
148
+
149
+
150
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
151
+ m.def("deform_conv_forward", &deform_conv_forward,
152
+ "deform forward");
153
+ m.def("deform_conv_backward_input", &deform_conv_backward_input,
154
+ "deform_conv_backward_input");
155
+ m.def("deform_conv_backward_parameters",
156
+ &deform_conv_backward_parameters,
157
+ "deform_conv_backward_parameters");
158
+ m.def("modulated_deform_conv_forward",
159
+ &modulated_deform_conv_forward,
160
+ "modulated deform conv forward");
161
+ m.def("modulated_deform_conv_backward",
162
+ &modulated_deform_conv_backward,
163
+ "modulated deform conv backward");
164
+ }
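The bindings registered above can be compiled ahead of time or loaded just in time with torch.utils.cpp_extension.load, mirroring the BASICSR_JIT fallback that fused_act.py uses below for its own op. A hedged sketch of the JIT route, where the extension name 'deform_conv' and the directory layout are assumptions made for illustration:

import os
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)  # assumed: the ops/dcn package directory
deform_conv_ext = load(
    'deform_conv',  # extension name, assumed for this sketch
    sources=[
        os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
        os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
    ],
)
# The loaded module then exposes the functions registered in PYBIND11_MODULE above,
# e.g. deform_conv_ext.modulated_deform_conv_forward(...).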
CodeFormer/basicsr/ops/fused_act/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+
3
+ __all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
CodeFormer/basicsr/ops/fused_act/fused_act.py ADDED
@@ -0,0 +1,89 @@
1
+ # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+
7
+ try:
8
+ from . import fused_act_ext
9
+ except ImportError:
10
+ import os
11
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
12
+ if BASICSR_JIT == 'True':
13
+ from torch.utils.cpp_extension import load
14
+ module_path = os.path.dirname(__file__)
15
+ fused_act_ext = load(
16
+ 'fused',
17
+ sources=[
18
+ os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
19
+ os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
20
+ ],
21
+ )
22
+
23
+
24
+ class FusedLeakyReLUFunctionBackward(Function):
25
+
26
+ @staticmethod
27
+ def forward(ctx, grad_output, out, negative_slope, scale):
28
+ ctx.save_for_backward(out)
29
+ ctx.negative_slope = negative_slope
30
+ ctx.scale = scale
31
+
32
+ empty = grad_output.new_empty(0)
33
+
34
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
35
+
36
+ dim = [0]
37
+
38
+ if grad_input.ndim > 2:
39
+ dim += list(range(2, grad_input.ndim))
40
+
41
+ grad_bias = grad_input.sum(dim).detach()
42
+
43
+ return grad_input, grad_bias
44
+
45
+ @staticmethod
46
+ def backward(ctx, gradgrad_input, gradgrad_bias):
47
+ out, = ctx.saved_tensors
48
+ gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
49
+ ctx.scale)
50
+
51
+ return gradgrad_out, None, None, None
52
+
53
+
54
+ class FusedLeakyReLUFunction(Function):
55
+
56
+ @staticmethod
57
+ def forward(ctx, input, bias, negative_slope, scale):
58
+ empty = input.new_empty(0)
59
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
60
+ ctx.save_for_backward(out)
61
+ ctx.negative_slope = negative_slope
62
+ ctx.scale = scale
63
+
64
+ return out
65
+
66
+ @staticmethod
67
+ def backward(ctx, grad_output):
68
+ out, = ctx.saved_tensors
69
+
70
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
71
+
72
+ return grad_input, grad_bias, None, None
73
+
74
+
75
+ class FusedLeakyReLU(nn.Module):
76
+
77
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
78
+ super().__init__()
79
+
80
+ self.bias = nn.Parameter(torch.zeros(channel))
81
+ self.negative_slope = negative_slope
82
+ self.scale = scale
83
+
84
+ def forward(self, input):
85
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
86
+
87
+
88
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
89
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
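A minimal usage sketch of the two exports above (illustrative only, not part of the commit), assuming the compiled fused_act_ext extension and a CUDA device are available:

import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(4, 64, 32, 32, device='cuda')
act = FusedLeakyReLU(channel=64).cuda()   # learnable per-channel bias, slope 0.2, gain sqrt(2)
y = act(x)                                # module form
y_fn = fused_leaky_relu(x, act.bias)      # equivalent functional form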
CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp ADDED
@@ -0,0 +1,26 @@
1
+ // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
2
+ #include <torch/extension.h>
3
+
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input,
6
+ const torch::Tensor& bias,
7
+ const torch::Tensor& refer,
8
+ int act, int grad, float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13
+
14
+ torch::Tensor fused_bias_act(const torch::Tensor& input,
15
+ const torch::Tensor& bias,
16
+ const torch::Tensor& refer,
17
+ int act, int grad, float alpha, float scale) {
18
+ CHECK_CUDA(input);
19
+ CHECK_CUDA(bias);
20
+
21
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
22
+ }
23
+
24
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
26
+ }
CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,100 @@
1
+ // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
2
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
3
+ //
4
+ // This work is made available under the Nvidia Source Code License-NC.
5
+ // To view a copy of this license, visit
6
+ // https://nvlabs.github.io/stylegan2/license.html
7
+
8
+ #include <torch/types.h>
9
+
10
+ #include <ATen/ATen.h>
11
+ #include <ATen/AccumulateType.h>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+
19
+ template <typename scalar_t>
20
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
21
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
22
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
23
+
24
+ scalar_t zero = 0.0;
25
+
26
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
27
+ scalar_t x = p_x[xi];
28
+
29
+ if (use_bias) {
30
+ x += p_b[(xi / step_b) % size_b];
31
+ }
32
+
33
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
34
+
35
+ scalar_t y;
36
+
37
+ switch (act * 10 + grad) {
38
+ default:
39
+ case 10: y = x; break;
40
+ case 11: y = x; break;
41
+ case 12: y = 0.0; break;
42
+
43
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
44
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
45
+ case 32: y = 0.0; break;
46
+ }
47
+
48
+ out[xi] = y * scale;
49
+ }
50
+ }
51
+
52
+
53
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
54
+ int act, int grad, float alpha, float scale) {
55
+ int curDevice = -1;
56
+ cudaGetDevice(&curDevice);
57
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
58
+
59
+ auto x = input.contiguous();
60
+ auto b = bias.contiguous();
61
+ auto ref = refer.contiguous();
62
+
63
+ int use_bias = b.numel() ? 1 : 0;
64
+ int use_ref = ref.numel() ? 1 : 0;
65
+
66
+ int size_x = x.numel();
67
+ int size_b = b.numel();
68
+ int step_b = 1;
69
+
70
+ for (int i = 1 + 1; i < x.dim(); i++) {
71
+ step_b *= x.size(i);
72
+ }
73
+
74
+ int loop_x = 4;
75
+ int block_size = 4 * 32;
76
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
77
+
78
+ auto y = torch::empty_like(x);
79
+
80
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
81
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
82
+ y.data_ptr<scalar_t>(),
83
+ x.data_ptr<scalar_t>(),
84
+ b.data_ptr<scalar_t>(),
85
+ ref.data_ptr<scalar_t>(),
86
+ act,
87
+ grad,
88
+ alpha,
89
+ scale,
90
+ loop_x,
91
+ size_x,
92
+ step_b,
93
+ size_b,
94
+ use_bias,
95
+ use_ref
96
+ );
97
+ });
98
+
99
+ return y;
100
+ }
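The switch in the kernel above dispatches on act * 10 + grad; only act = 3 (leaky ReLU) is wired up, with grad = 0/1 selecting the forward/backward rule, and the per-channel bias located via (xi / step_b) % size_b, i.e. dimension 1 of a contiguous NCHW tensor. As a reference only (an illustrative assumption, not code used by the extension), case 30 is equivalent to this plain PyTorch function:

import torch

def fused_bias_act_reference(x, bias, negative_slope=0.2, scale=2**0.5):
    # add the bias along dim 1, then apply leaky ReLU and the gain
    x = x + bias.view(1, -1, *([1] * (x.ndim - 2)))
    return torch.where(x > 0, x, x * negative_slope) * scale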
CodeFormer/basicsr/ops/upfirdn2d/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .upfirdn2d import upfirdn2d
2
+
3
+ __all__ = ['upfirdn2d']
CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp ADDED
@@ -0,0 +1,24 @@
1
+ // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
2
+ #include <torch/extension.h>
3
+
4
+
5
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
6
+ int up_x, int up_y, int down_x, int down_y,
7
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
10
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
12
+
13
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
14
+ int up_x, int up_y, int down_x, int down_y,
15
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
16
+ CHECK_CUDA(input);
17
+ CHECK_CUDA(kernel);
18
+
19
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
20
+ }
21
+
22
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
23
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
24
+ }
CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,370 @@
1
+ // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
2
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
3
+ //
4
+ // This work is made available under the Nvidia Source Code License-NC.
5
+ // To view a copy of this license, visit
6
+ // https://nvlabs.github.io/stylegan2/license.html
7
+
8
+ #include <torch/types.h>
9
+
10
+ #include <ATen/ATen.h>
11
+ #include <ATen/AccumulateType.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+ #include <ATen/cuda/CUDAContext.h>
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+ struct UpFirDn2DKernelParams {
29
+ int up_x;
30
+ int up_y;
31
+ int down_x;
32
+ int down_y;
33
+ int pad_x0;
34
+ int pad_x1;
35
+ int pad_y0;
36
+ int pad_y1;
37
+
38
+ int major_dim;
39
+ int in_h;
40
+ int in_w;
41
+ int minor_dim;
42
+ int kernel_h;
43
+ int kernel_w;
44
+ int out_h;
45
+ int out_w;
46
+ int loop_major;
47
+ int loop_x;
48
+ };
49
+
50
+ template <typename scalar_t>
51
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
52
+ const scalar_t *kernel,
53
+ const UpFirDn2DKernelParams p) {
54
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
55
+ int out_y = minor_idx / p.minor_dim;
56
+ minor_idx -= out_y * p.minor_dim;
57
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
58
+ int major_idx_base = blockIdx.z * p.loop_major;
59
+
60
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
61
+ major_idx_base >= p.major_dim) {
62
+ return;
63
+ }
64
+
65
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
66
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
67
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
68
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
69
+
70
+ for (int loop_major = 0, major_idx = major_idx_base;
71
+ loop_major < p.loop_major && major_idx < p.major_dim;
72
+ loop_major++, major_idx++) {
73
+ for (int loop_x = 0, out_x = out_x_base;
74
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
75
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
76
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
77
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
78
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
79
+
80
+ const scalar_t *x_p =
81
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
82
+ minor_idx];
83
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
84
+ int x_px = p.minor_dim;
85
+ int k_px = -p.up_x;
86
+ int x_py = p.in_w * p.minor_dim;
87
+ int k_py = -p.up_y * p.kernel_w;
88
+
89
+ scalar_t v = 0.0f;
90
+
91
+ for (int y = 0; y < h; y++) {
92
+ for (int x = 0; x < w; x++) {
93
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
94
+ x_p += x_px;
95
+ k_p += k_px;
96
+ }
97
+
98
+ x_p += x_py - w * x_px;
99
+ k_p += k_py - w * k_px;
100
+ }
101
+
102
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
103
+ minor_idx] = v;
104
+ }
105
+ }
106
+ }
107
+
108
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
109
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
110
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
111
+ const scalar_t *kernel,
112
+ const UpFirDn2DKernelParams p) {
113
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
114
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
115
+
116
+ __shared__ volatile float sk[kernel_h][kernel_w];
117
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
118
+
119
+ int minor_idx = blockIdx.x;
120
+ int tile_out_y = minor_idx / p.minor_dim;
121
+ minor_idx -= tile_out_y * p.minor_dim;
122
+ tile_out_y *= tile_out_h;
123
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
124
+ int major_idx_base = blockIdx.z * p.loop_major;
125
+
126
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
127
+ major_idx_base >= p.major_dim) {
128
+ return;
129
+ }
130
+
131
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
132
+ tap_idx += blockDim.x) {
133
+ int ky = tap_idx / kernel_w;
134
+ int kx = tap_idx - ky * kernel_w;
135
+ scalar_t v = 0.0;
136
+
137
+ if (kx < p.kernel_w & ky < p.kernel_h) {
138
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
139
+ }
140
+
141
+ sk[ky][kx] = v;
142
+ }
143
+
144
+ for (int loop_major = 0, major_idx = major_idx_base;
145
+ loop_major < p.loop_major & major_idx < p.major_dim;
146
+ loop_major++, major_idx++) {
147
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
148
+ loop_x < p.loop_x & tile_out_x < p.out_w;
149
+ loop_x++, tile_out_x += tile_out_w) {
150
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
151
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
152
+ int tile_in_x = floor_div(tile_mid_x, up_x);
153
+ int tile_in_y = floor_div(tile_mid_y, up_y);
154
+
155
+ __syncthreads();
156
+
157
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
158
+ in_idx += blockDim.x) {
159
+ int rel_in_y = in_idx / tile_in_w;
160
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
161
+ int in_x = rel_in_x + tile_in_x;
162
+ int in_y = rel_in_y + tile_in_y;
163
+
164
+ scalar_t v = 0.0;
165
+
166
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
167
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
168
+ p.minor_dim +
169
+ minor_idx];
170
+ }
171
+
172
+ sx[rel_in_y][rel_in_x] = v;
173
+ }
174
+
175
+ __syncthreads();
176
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
177
+ out_idx += blockDim.x) {
178
+ int rel_out_y = out_idx / tile_out_w;
179
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
180
+ int out_x = rel_out_x + tile_out_x;
181
+ int out_y = rel_out_y + tile_out_y;
182
+
183
+ int mid_x = tile_mid_x + rel_out_x * down_x;
184
+ int mid_y = tile_mid_y + rel_out_y * down_y;
185
+ int in_x = floor_div(mid_x, up_x);
186
+ int in_y = floor_div(mid_y, up_y);
187
+ int rel_in_x = in_x - tile_in_x;
188
+ int rel_in_y = in_y - tile_in_y;
189
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
190
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
191
+
192
+ scalar_t v = 0.0;
193
+
194
+ #pragma unroll
195
+ for (int y = 0; y < kernel_h / up_y; y++)
196
+ #pragma unroll
197
+ for (int x = 0; x < kernel_w / up_x; x++)
198
+ v += sx[rel_in_y + y][rel_in_x + x] *
199
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
200
+
201
+ if (out_x < p.out_w & out_y < p.out_h) {
202
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
203
+ minor_idx] = v;
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+
210
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
211
+ const torch::Tensor &kernel, int up_x, int up_y,
212
+ int down_x, int down_y, int pad_x0, int pad_x1,
213
+ int pad_y0, int pad_y1) {
214
+ int curDevice = -1;
215
+ cudaGetDevice(&curDevice);
216
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
217
+
218
+ UpFirDn2DKernelParams p;
219
+
220
+ auto x = input.contiguous();
221
+ auto k = kernel.contiguous();
222
+
223
+ p.major_dim = x.size(0);
224
+ p.in_h = x.size(1);
225
+ p.in_w = x.size(2);
226
+ p.minor_dim = x.size(3);
227
+ p.kernel_h = k.size(0);
228
+ p.kernel_w = k.size(1);
229
+ p.up_x = up_x;
230
+ p.up_y = up_y;
231
+ p.down_x = down_x;
232
+ p.down_y = down_y;
233
+ p.pad_x0 = pad_x0;
234
+ p.pad_x1 = pad_x1;
235
+ p.pad_y0 = pad_y0;
236
+ p.pad_y1 = pad_y1;
237
+
238
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
239
+ p.down_y;
240
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
241
+ p.down_x;
242
+
243
+ auto out =
244
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
245
+
246
+ int mode = -1;
247
+
248
+ int tile_out_h = -1;
249
+ int tile_out_w = -1;
250
+
251
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
252
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
253
+ mode = 1;
254
+ tile_out_h = 16;
255
+ tile_out_w = 64;
256
+ }
257
+
258
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
259
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
260
+ mode = 2;
261
+ tile_out_h = 16;
262
+ tile_out_w = 64;
263
+ }
264
+
265
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
266
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
267
+ mode = 3;
268
+ tile_out_h = 16;
269
+ tile_out_w = 64;
270
+ }
271
+
272
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
273
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
274
+ mode = 4;
275
+ tile_out_h = 16;
276
+ tile_out_w = 64;
277
+ }
278
+
279
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
280
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
281
+ mode = 5;
282
+ tile_out_h = 8;
283
+ tile_out_w = 32;
284
+ }
285
+
286
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
287
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
288
+ mode = 6;
289
+ tile_out_h = 8;
290
+ tile_out_w = 32;
291
+ }
292
+
293
+ dim3 block_size;
294
+ dim3 grid_size;
295
+
296
+ if (tile_out_h > 0 && tile_out_w > 0) {
297
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
298
+ p.loop_x = 1;
299
+ block_size = dim3(32 * 8, 1, 1);
300
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
301
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
302
+ (p.major_dim - 1) / p.loop_major + 1);
303
+ } else {
304
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
305
+ p.loop_x = 4;
306
+ block_size = dim3(4, 32, 1);
307
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
308
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
309
+ (p.major_dim - 1) / p.loop_major + 1);
310
+ }
311
+
312
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
313
+ switch (mode) {
314
+ case 1:
315
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
316
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
317
+ x.data_ptr<scalar_t>(),
318
+ k.data_ptr<scalar_t>(), p);
319
+
320
+ break;
321
+
322
+ case 2:
323
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
324
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
325
+ x.data_ptr<scalar_t>(),
326
+ k.data_ptr<scalar_t>(), p);
327
+
328
+ break;
329
+
330
+ case 3:
331
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
332
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
333
+ x.data_ptr<scalar_t>(),
334
+ k.data_ptr<scalar_t>(), p);
335
+
336
+ break;
337
+
338
+ case 4:
339
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
340
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
341
+ x.data_ptr<scalar_t>(),
342
+ k.data_ptr<scalar_t>(), p);
343
+
344
+ break;
345
+
346
+ case 5:
347
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
348
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
349
+ x.data_ptr<scalar_t>(),
350
+ k.data_ptr<scalar_t>(), p);
351
+
352
+ break;
353
+
354
+ case 6:
355
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
356
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
357
+ x.data_ptr<scalar_t>(),
358
+ k.data_ptr<scalar_t>(), p);
359
+
360
+ break;
361
+
362
+ default:
363
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
364
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
365
+ k.data_ptr<scalar_t>(), p);
366
+ }
367
+ });
368
+
369
+ return out;
370
+ }
CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py ADDED
@@ -0,0 +1,186 @@
1
+ # modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.nn import functional as F
6
+
7
+ try:
8
+ from . import upfirdn2d_ext
9
+ except ImportError:
10
+ import os
11
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
12
+ if BASICSR_JIT == 'True':
13
+ from torch.utils.cpp_extension import load
14
+ module_path = os.path.dirname(__file__)
15
+ upfirdn2d_ext = load(
16
+ 'upfirdn2d',
17
+ sources=[
18
+ os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
19
+ os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
20
+ ],
21
+ )
22
+
23
+
24
+ class UpFirDn2dBackward(Function):
25
+
26
+ @staticmethod
27
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
28
+
29
+ up_x, up_y = up
30
+ down_x, down_y = down
31
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
32
+
33
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
34
+
35
+ grad_input = upfirdn2d_ext.upfirdn2d(
36
+ grad_output,
37
+ grad_kernel,
38
+ down_x,
39
+ down_y,
40
+ up_x,
41
+ up_y,
42
+ g_pad_x0,
43
+ g_pad_x1,
44
+ g_pad_y0,
45
+ g_pad_y1,
46
+ )
47
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
48
+
49
+ ctx.save_for_backward(kernel)
50
+
51
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
52
+
53
+ ctx.up_x = up_x
54
+ ctx.up_y = up_y
55
+ ctx.down_x = down_x
56
+ ctx.down_y = down_y
57
+ ctx.pad_x0 = pad_x0
58
+ ctx.pad_x1 = pad_x1
59
+ ctx.pad_y0 = pad_y0
60
+ ctx.pad_y1 = pad_y1
61
+ ctx.in_size = in_size
62
+ ctx.out_size = out_size
63
+
64
+ return grad_input
65
+
66
+ @staticmethod
67
+ def backward(ctx, gradgrad_input):
68
+ kernel, = ctx.saved_tensors
69
+
70
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
71
+
72
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
73
+ gradgrad_input,
74
+ kernel,
75
+ ctx.up_x,
76
+ ctx.up_y,
77
+ ctx.down_x,
78
+ ctx.down_y,
79
+ ctx.pad_x0,
80
+ ctx.pad_x1,
81
+ ctx.pad_y0,
82
+ ctx.pad_y1,
83
+ )
84
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
85
+ # ctx.out_size[1], ctx.in_size[3])
86
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
87
+
88
+ return gradgrad_out, None, None, None, None, None, None, None, None
89
+
90
+
91
+ class UpFirDn2d(Function):
92
+
93
+ @staticmethod
94
+ def forward(ctx, input, kernel, up, down, pad):
95
+ up_x, up_y = up
96
+ down_x, down_y = down
97
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
98
+
99
+ kernel_h, kernel_w = kernel.shape
100
+ batch, channel, in_h, in_w = input.shape
101
+ ctx.in_size = input.shape
102
+
103
+ input = input.reshape(-1, in_h, in_w, 1)
104
+
105
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
106
+
107
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
108
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
109
+ ctx.out_size = (out_h, out_w)
110
+
111
+ ctx.up = (up_x, up_y)
112
+ ctx.down = (down_x, down_y)
113
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
114
+
115
+ g_pad_x0 = kernel_w - pad_x0 - 1
116
+ g_pad_y0 = kernel_h - pad_y0 - 1
117
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
118
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
119
+
120
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
121
+
122
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
123
+ # out = out.view(major, out_h, out_w, minor)
124
+ out = out.view(-1, channel, out_h, out_w)
125
+
126
+ return out
127
+
128
+ @staticmethod
129
+ def backward(ctx, grad_output):
130
+ kernel, grad_kernel = ctx.saved_tensors
131
+
132
+ grad_input = UpFirDn2dBackward.apply(
133
+ grad_output,
134
+ kernel,
135
+ grad_kernel,
136
+ ctx.up,
137
+ ctx.down,
138
+ ctx.pad,
139
+ ctx.g_pad,
140
+ ctx.in_size,
141
+ ctx.out_size,
142
+ )
143
+
144
+ return grad_input, None, None, None, None
145
+
146
+
147
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
148
+ if input.device.type == 'cpu':
149
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
150
+ else:
151
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
152
+
153
+ return out
154
+
155
+
156
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
157
+ _, channel, in_h, in_w = input.shape
158
+ input = input.reshape(-1, in_h, in_w, 1)
159
+
160
+ _, in_h, in_w, minor = input.shape
161
+ kernel_h, kernel_w = kernel.shape
162
+
163
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
164
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
165
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
166
+
167
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
168
+ out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
169
+
170
+ out = out.permute(0, 3, 1, 2)
171
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
172
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
173
+ out = F.conv2d(out, w)
174
+ out = out.reshape(
175
+ -1,
176
+ minor,
177
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
178
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
179
+ )
180
+ out = out.permute(0, 2, 3, 1)
181
+ out = out[:, ::down_y, ::down_x, :]
182
+
183
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
184
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
185
+
186
+ return out.view(-1, channel, out_h, out_w)
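A small usage sketch (an illustrative assumption, not part of the module): a 2x upsample with a normalized 4-tap binomial kernel, the typical StyleGAN2-style call. On CPU this routes through upfirdn2d_native above; on GPU it requires the compiled upfirdn2d_ext op.

import torch
from basicsr.ops.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 16, 16)
k = torch.tensor([1., 3., 3., 1.])
k = k[:, None] * k[None, :]
k = k / k.sum()
# pad=(2, 1) keeps the output at exactly 2x the input resolution for a 4-tap kernel
y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(y.shape)  # torch.Size([1, 3, 32, 32])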
CodeFormer/basicsr/setup.py ADDED
@@ -0,0 +1,165 @@
1
+ #!/usr/bin/env python
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ import torch
10
+ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
11
+
12
+ version_file = './basicsr/version.py'
13
+
14
+
15
+ def readme():
16
+ with open('README.md', encoding='utf-8') as f:
17
+ content = f.read()
18
+ return content
19
+
20
+
21
+ def get_git_hash():
22
+
23
+ def _minimal_ext_cmd(cmd):
24
+ # construct minimal environment
25
+ env = {}
26
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
27
+ v = os.environ.get(k)
28
+ if v is not None:
29
+ env[k] = v
30
+ # LANGUAGE is used on win32
31
+ env['LANGUAGE'] = 'C'
32
+ env['LANG'] = 'C'
33
+ env['LC_ALL'] = 'C'
34
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
35
+ return out
36
+
37
+ try:
38
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
39
+ sha = out.strip().decode('ascii')
40
+ except OSError:
41
+ sha = 'unknown'
42
+
43
+ return sha
44
+
45
+
46
+ def get_hash():
47
+ if os.path.exists('.git'):
48
+ sha = get_git_hash()[:7]
49
+ elif os.path.exists(version_file):
50
+ try:
51
+ from version import __version__
52
+ sha = __version__.split('+')[-1]
53
+ except ImportError:
54
+ raise ImportError('Unable to get git version')
55
+ else:
56
+ sha = 'unknown'
57
+
58
+ return sha
59
+
60
+
61
+ def write_version_py():
62
+ content = """# GENERATED VERSION FILE
63
+ # TIME: {}
64
+ __version__ = '{}'
65
+ __gitsha__ = '{}'
66
+ version_info = ({})
67
+ """
68
+ sha = get_hash()
69
+ with open('./basicsr/VERSION', 'r') as f:
70
+ SHORT_VERSION = f.read().strip()
71
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
72
+
73
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
74
+ with open(version_file, 'w') as f:
75
+ f.write(version_file_str)
76
+
77
+
78
+ def get_version():
79
+ with open(version_file, 'r') as f:
80
+ exec(compile(f.read(), version_file, 'exec'))
81
+ return locals()['__version__']
82
+
83
+
84
+ def make_cuda_ext(name, module, sources, sources_cuda=None):
85
+ if sources_cuda is None:
86
+ sources_cuda = []
87
+ define_macros = []
88
+ extra_compile_args = {'cxx': []}
89
+
90
+ if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91
+ define_macros += [('WITH_CUDA', None)]
92
+ extension = CUDAExtension
93
+ extra_compile_args['nvcc'] = [
94
+ '-D__CUDA_NO_HALF_OPERATORS__',
95
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
96
+ '-D__CUDA_NO_HALF2_OPERATORS__',
97
+ ]
98
+ sources += sources_cuda
99
+ else:
100
+ print(f'Compiling {name} without CUDA')
101
+ extension = CppExtension
102
+
103
+ return extension(
104
+ name=f'{module}.{name}',
105
+ sources=[os.path.join(*module.split('.'), p) for p in sources],
106
+ define_macros=define_macros,
107
+ extra_compile_args=extra_compile_args)
108
+
109
+
110
+ def get_requirements(filename='requirements.txt'):
111
+ with open(os.path.join('.', filename), 'r') as f:
112
+ requires = [line.replace('\n', '') for line in f.readlines()]
113
+ return requires
114
+
115
+
116
+ if __name__ == '__main__':
117
+ if '--cuda_ext' in sys.argv:
118
+ ext_modules = [
119
+ make_cuda_ext(
120
+ name='deform_conv_ext',
121
+ module='ops.dcn',
122
+ sources=['src/deform_conv_ext.cpp'],
123
+ sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
124
+ make_cuda_ext(
125
+ name='fused_act_ext',
126
+ module='ops.fused_act',
127
+ sources=['src/fused_bias_act.cpp'],
128
+ sources_cuda=['src/fused_bias_act_kernel.cu']),
129
+ make_cuda_ext(
130
+ name='upfirdn2d_ext',
131
+ module='ops.upfirdn2d',
132
+ sources=['src/upfirdn2d.cpp'],
133
+ sources_cuda=['src/upfirdn2d_kernel.cu']),
134
+ ]
135
+ sys.argv.remove('--cuda_ext')
136
+ else:
137
+ ext_modules = []
138
+
139
+ write_version_py()
140
+ setup(
141
+ name='basicsr',
142
+ version=get_version(),
143
+ description='Open Source Image and Video Super-Resolution Toolbox',
144
+ long_description=readme(),
145
+ long_description_content_type='text/markdown',
146
+ author='Xintao Wang',
147
+ author_email='xintao.wang@outlook.com',
148
+ keywords='computer vision, restoration, super resolution',
149
+ url='https://github.com/xinntao/BasicSR',
150
+ include_package_data=True,
151
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
152
+ classifiers=[
153
+ 'Development Status :: 4 - Beta',
154
+ 'License :: OSI Approved :: Apache Software License',
155
+ 'Operating System :: OS Independent',
156
+ 'Programming Language :: Python :: 3',
157
+ 'Programming Language :: Python :: 3.7',
158
+ 'Programming Language :: Python :: 3.8',
159
+ ],
160
+ license='Apache License 2.0',
161
+ setup_requires=['cython', 'numpy'],
162
+ install_requires=get_requirements(),
163
+ ext_modules=ext_modules,
164
+ cmdclass={'build_ext': BuildExtension},
165
+ zip_safe=False)
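Besides building the three extensions ahead of time with this script's --cuda_ext flag, the op wrappers above also support just-in-time compilation through torch's cpp_extension loader when the BASICSR_JIT environment variable is set; a minimal sketch (the variable must be set before the first import):

import os
os.environ['BASICSR_JIT'] = 'True'  # checked in the except-ImportError branches of the op wrappers

# the first import triggers torch.utils.cpp_extension.load() on the .cpp/.cu sources
from basicsr.ops.fused_act import fused_leaky_relu
from basicsr.ops.upfirdn2d import upfirdn2d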
CodeFormer/basicsr/train.py ADDED
@@ -0,0 +1,225 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import math
5
+ import copy
6
+ import random
7
+ import time
8
+ import torch
9
+ from os import path as osp
10
+
11
+ from basicsr.data import build_dataloader, build_dataset
12
+ from basicsr.data.data_sampler import EnlargedSampler
13
+ from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
14
+ from basicsr.models import build_model
15
+ from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
16
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
17
+ from basicsr.utils.dist_util import get_dist_info, init_dist
18
+ from basicsr.utils.options import dict2str, parse
19
+
20
+ import warnings
21
+ # ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
22
+ warnings.filterwarnings("ignore", category=UserWarning)
23
+
24
+ def parse_options(root_path, is_train=True):
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
27
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
28
+ parser.add_argument('--local_rank', type=int, default=0)
29
+ args = parser.parse_args()
30
+ opt = parse(args.opt, root_path, is_train=is_train)
31
+
32
+ # distributed settings
33
+ if args.launcher == 'none':
34
+ opt['dist'] = False
35
+ print('Disable distributed.', flush=True)
36
+ else:
37
+ opt['dist'] = True
38
+ if args.launcher == 'slurm' and 'dist_params' in opt:
39
+ init_dist(args.launcher, **opt['dist_params'])
40
+ else:
41
+ init_dist(args.launcher)
42
+
43
+ opt['rank'], opt['world_size'] = get_dist_info()
44
+
45
+ # random seed
46
+ seed = opt.get('manual_seed')
47
+ if seed is None:
48
+ seed = random.randint(1, 10000)
49
+ opt['manual_seed'] = seed
50
+ set_random_seed(seed + opt['rank'])
51
+
52
+ return opt
53
+
54
+
55
+ def init_loggers(opt):
56
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
57
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
58
+ logger.info(get_env_info())
59
+ logger.info(dict2str(opt))
60
+
61
+ # initialize wandb logger before tensorboard logger to allow proper sync:
62
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
63
+ assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
64
+ init_wandb_logger(opt)
65
+ tb_logger = None
66
+ if opt['logger'].get('use_tb_logger'):
67
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
68
+ return logger, tb_logger
69
+
70
+
71
+ def create_train_val_dataloader(opt, logger):
72
+ # create train and val dataloaders
73
+ train_loader, val_loader = None, None
74
+ for phase, dataset_opt in opt['datasets'].items():
75
+ if phase == 'train':
76
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
77
+ train_set = build_dataset(dataset_opt)
78
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
79
+ train_loader = build_dataloader(
80
+ train_set,
81
+ dataset_opt,
82
+ num_gpu=opt['num_gpu'],
83
+ dist=opt['dist'],
84
+ sampler=train_sampler,
85
+ seed=opt['manual_seed'])
86
+
87
+ num_iter_per_epoch = math.ceil(
88
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
89
+ total_iters = int(opt['train']['total_iter'])
90
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
91
+ logger.info('Training statistics:'
92
+ f'\n\tNumber of train images: {len(train_set)}'
93
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
94
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
95
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
96
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
97
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
98
+
99
+ elif phase == 'val':
100
+ val_set = build_dataset(dataset_opt)
101
+ val_loader = build_dataloader(
102
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
103
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
104
+ else:
105
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
106
+
107
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
108
+
109
+
110
+ def train_pipeline(root_path):
111
+ # parse options, set distributed setting, set ramdom seed
112
+ opt = parse_options(root_path, is_train=True)
113
+
114
+ torch.backends.cudnn.benchmark = True
115
+ # torch.backends.cudnn.deterministic = True
116
+
117
+ # load resume states if necessary
118
+ if opt['path'].get('resume_state'):
119
+ device_id = torch.cuda.current_device()
120
+ resume_state = torch.load(
121
+ opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
122
+ else:
123
+ resume_state = None
124
+
125
+ # mkdir for experiments and logger
126
+ if resume_state is None:
127
+ make_exp_dirs(opt)
128
+ if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
129
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
130
+
131
+ # initialize loggers
132
+ logger, tb_logger = init_loggers(opt)
133
+
134
+ # create train and validation dataloaders
135
+ result = create_train_val_dataloader(opt, logger)
136
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
137
+
138
+ # create model
139
+ if resume_state: # resume training
140
+ check_resume(opt, resume_state['iter'])
141
+ model = build_model(opt)
142
+ model.resume_training(resume_state) # handle optimizers and schedulers
143
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
144
+ start_epoch = resume_state['epoch']
145
+ current_iter = resume_state['iter']
146
+ else:
147
+ model = build_model(opt)
148
+ start_epoch = 0
149
+ current_iter = 0
150
+
151
+ # create message logger (formatted outputs)
152
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
153
+
154
+ # dataloader prefetcher
155
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
156
+ if prefetch_mode is None or prefetch_mode == 'cpu':
157
+ prefetcher = CPUPrefetcher(train_loader)
158
+ elif prefetch_mode == 'cuda':
159
+ prefetcher = CUDAPrefetcher(train_loader, opt)
160
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
161
+ if opt['datasets']['train'].get('pin_memory') is not True:
162
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
163
+ else:
164
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
165
+
166
+ # training
167
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
168
+ data_time, iter_time = time.time(), time.time()
169
+ start_time = time.time()
170
+
171
+ for epoch in range(start_epoch, total_epochs + 1):
172
+ train_sampler.set_epoch(epoch)
173
+ prefetcher.reset()
174
+ train_data = prefetcher.next()
175
+
176
+ while train_data is not None:
177
+ data_time = time.time() - data_time
178
+
179
+ current_iter += 1
180
+ if current_iter > total_iters:
181
+ break
182
+ # update learning rate
183
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
184
+ # training
185
+ model.feed_data(train_data)
186
+ model.optimize_parameters(current_iter)
187
+ iter_time = time.time() - iter_time
188
+ # log
189
+ if current_iter % opt['logger']['print_freq'] == 0:
190
+ log_vars = {'epoch': epoch, 'iter': current_iter}
191
+ log_vars.update({'lrs': model.get_current_learning_rate()})
192
+ log_vars.update({'time': iter_time, 'data_time': data_time})
193
+ log_vars.update(model.get_current_log())
194
+ msg_logger(log_vars)
195
+
196
+ # save models and training states
197
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
198
+ logger.info('Saving models and training states.')
199
+ model.save(epoch, current_iter)
200
+
201
+ # validation
202
+ if opt.get('val') is not None and opt['datasets'].get('val') is not None \
203
+ and (current_iter % opt['val']['val_freq'] == 0):
204
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
205
+
206
+ data_time = time.time()
207
+ iter_time = time.time()
208
+ train_data = prefetcher.next()
209
+ # end of iter
210
+
211
+ # end of epoch
212
+
213
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
214
+ logger.info(f'End of training. Time consumed: {consumed_time}')
215
+ logger.info('Save the latest model.')
216
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
217
+ if opt.get('val') is not None and opt['datasets'].get('val'):
218
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
219
+ if tb_logger:
220
+ tb_logger.close()
221
+
222
+
223
+ if __name__ == '__main__':
224
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
225
+ train_pipeline(root_path)
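train_pipeline is normally driven from the command line, e.g. python basicsr/train.py -opt <your_option.yml> (with --launcher pytorch when run under torch.distributed.launch); the option file name here is a placeholder, not a file shipped in this commit. A hedged programmatic equivalent:

import sys
from os import path as osp

from basicsr.train import train_pipeline

root_path = osp.abspath('.')  # assumption: run from the CodeFormer repository root
sys.argv = ['train.py', '-opt', 'options/your_config.yml']  # placeholder YAML consumed by parse_options()
train_pipeline(root_path)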