befozg committed on
Commit f0de4e8 · 1 Parent(s): 9967c2f

added initial portrait transfer app

Files changed (10)
  1. .gitignore +174 -0
  2. app.py +107 -0
  3. requirements.txt +38 -0
  4. slider.html +137 -0
  5. tools/__init__.py +3 -0
  6. tools/inference.py +56 -0
  7. tools/model.py +296 -0
  8. tools/normalizer.py +261 -0
  9. tools/stylematte.py +506 -0
  10. tools/util.py +345 -0
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ config/*
164
+ trainer/__pycache__/
165
+ trainer/__pycache__/*
166
+ __pycache__/*
167
+ checkpoints/*.pth
168
+ */*.pth
169
+ */checkpoints/best_pure.pth
170
+ checkpoints/best_pure.pth
171
+ *.ipynb
172
+ .ipynb_checkpoints/*
173
+ flagged/
174
+ assets/
app.py ADDED
@@ -0,0 +1,107 @@
1
+ import gradio as gr
2
+ from tools import Inference, Matting, log
3
+ from omegaconf import OmegaConf
4
+ import os
5
+ import sys
6
+ import numpy as np
7
+ import torchvision.transforms.functional as tf
8
+ from PIL import Image
9
+
10
+ args = OmegaConf.load("./config/test.yaml")
11
+
12
+ global_comp = None
13
+ global_mask = None
14
+
15
+ log("Model loading")
16
+ phnet = Inference(**args)
17
+ stylematte = Matting(**args)
18
+ log("Model loaded")
19
+
20
+
21
+ def harmonize(comp, mask):
22
+ log("Inference started")
23
+ if comp is None or mask is None:
24
+ log("Empty source")
25
+ return np.zeros((16, 16, 3))
26
+
27
+ comp = comp.convert('RGB')
28
+ mask = mask.convert('1')
29
+ in_shape = comp.size[::-1]
30
+
31
+ comp = tf.resize(comp, [args.image_size, args.image_size])
32
+ mask = tf.resize(mask, [args.image_size, args.image_size])
33
+
34
+ compt = tf.to_tensor(comp)
35
+ maskt = tf.to_tensor(mask)
36
+ res = phnet.harmonize(compt, maskt)
37
+ res = tf.resize(res, in_shape)
38
+
39
+ log("Inference finished")
40
+
41
+ return np.uint8((res*255)[0].permute(1, 2, 0).numpy())
42
+
43
+
44
+ def extract_matte(img, back):
45
+ mask, fg = stylematte.extract(img)
46
+ fg_pil = Image.fromarray(np.uint8(fg))
47
+
48
+ composite = fg + (1 - mask[:, :, None]) * \
49
+ np.array(back.resize(mask.shape[::-1]))
50
+ composite_pil = Image.fromarray(np.uint8(composite))
51
+
52
+ global_comp = composite_pil
53
+ global_mask = mask
54
+
55
+ return [composite_pil, mask, fg_pil]
56
+
57
+
58
+ def css(height=3, scale=2):
59
+ return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"
60
+
61
+
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown(
64
+ """
65
+ # Welcome to the portrait transfer demo app!
66
+ Select a source portrait image and a new background.
67
+ """)
68
+ btn_compose = gr.Button(value="Compose")
69
+
70
+ with gr.Row():
71
+ input_ui = gr.Image(
72
+ type="numpy", label='Source image to extract foreground')
73
+ back_ui = gr.Image(type="pil", label='The new background')
74
+
75
+ gr.Examples(
76
+ examples=[["./assets/comp.jpg", "./assets/back.jpg"]],
77
+ inputs=[input_ui, back_ui],
78
+ )
79
+
80
+ gr.Markdown(
81
+ """
82
+ ## Resulting alpha matte and extracted foreground.
83
+ """)
84
+ with gr.Row():
85
+ matte_ui = gr.Image(type="pil", label='Alpha matte')
86
+ fg_ui = gr.Image(type="pil", image_mode='RGBA',
87
+ label='Extracted foreground')
88
+
89
+ gr.Markdown(
90
+ """
91
+ ## Click the button and compare the composite with the harmonized version.
92
+ """)
93
+ btn_harmonize = gr.Button(value="Harmonize composite")
94
+
95
+ with gr.Row():
96
+ composite_ui = gr.Image(type="pil", label='Composite')
97
+ harmonized_ui = gr.Image(
98
+ type="pil", label='Harmonized composite', css=css(3, 3))
99
+
100
+ btn_compose.click(extract_matte, inputs=[input_ui, back_ui], outputs=[
101
+ composite_ui, matte_ui, fg_ui])
102
+ btn_harmonize.click(harmonize, inputs=[
103
+ composite_ui, matte_ui], outputs=[harmonized_ui])
104
+
105
+
106
+ log("Interface created")
107
+ demo.launch(share=True)
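For reference, a minimal headless sketch of the same pipeline that app.py wires into Gradio. It assumes config/test.yaml and the checkpoints it references are available locally (they are not part of this commit), and the input file names are placeholders:

import numpy as np
from PIL import Image
import torchvision.transforms.functional as tf
from omegaconf import OmegaConf
from tools import Inference, Matting

args = OmegaConf.load("./config/test.yaml")   # same config app.py loads
phnet = Inference(**args)                     # harmonization model (PHNet)
stylematte = Matting(**args)                  # matting model (StyleMatte)

source = np.array(Image.open("portrait.jpg").convert("RGB"))   # placeholder path
background = Image.open("background.jpg").convert("RGB")       # placeholder path

# 1. Extract the alpha matte and foreground, then paste them onto the new background.
mask, fg = stylematte.extract(source)
composite = fg + (1 - mask[:, :, None]) * np.array(background.resize(mask.shape[::-1]))

# 2. Harmonize the composite; Inference.harmonize resizes inputs to args.image_size
#    internally, so the result comes back at that resolution with values in [0, 1].
comp_t = tf.to_tensor(np.uint8(composite))    # (3, H, W)
mask_t = tf.to_tensor(np.float32(mask))       # (1, H, W)
result = phnet.harmonize(comp_t, mask_t)      # (1, 3, image_size, image_size)
out = np.uint8((result * 255)[0].permute(1, 2, 0).numpy())
Image.fromarray(out).save("harmonized.png")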
requirements.txt ADDED
@@ -0,0 +1,38 @@
1
+ gradio==3.30.0
2
+ gradio_client==0.2.4
3
+ huggingface-hub==0.14.1
4
+ imageio==2.25.1
5
+ imgcat==0.5.0
6
+ ipykernel==6.16.0
7
+ ipython==8.5.0
8
+ ipywidgets==8.0.2
9
+ kiwisolver==1.4.2
10
+ kornia==0.6.9
11
+ legacy==0.1.6
12
+ numpy==1.21.6
13
+ omegaconf==2.2.3
14
+ opencv-python==4.5.5.62
15
+ opencv-python-headless==4.7.0.68
16
+ packaging==21.3
17
+ pandas==1.4.2
18
+ parso==0.8.3
19
+ Pillow==9.4.0
20
+ protobuf==3.20.1
21
+ Pygments==2.13.0
22
+ PyMatting==1.1.8
23
+ pyparsing==3.0.9
24
+ pyrsistent==0.19.3
25
+ scikit-image==0.19.3
26
+ scikit-learn==1.1.1
27
+ scipy==1.10.0
28
+ seaborn==0.12.2
29
+ sklearn==0.0
30
+ sniffio==1.3.0
31
+ soupsieve==2.4
32
+ timm==0.6.12
33
+ torch==1.11.0
34
+ torchaudio==0.11.0
35
+ torchvision==0.12.0
36
+ tornado==6.2
37
+ tqdm==4.64.1
38
+ transformers==4.28.1
slider.html ADDED
@@ -0,0 +1,137 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
5
+ <style>
6
+ * {box-sizing: border-box;}
7
+
8
+ .img-comp-container {
9
+ position: relative;
10
+ height: 200px; /*should be the same height as the images*/
11
+ }
12
+
13
+ .img-comp-img {
14
+ position: absolute;
15
+ width: auto;
16
+ height: auto;
17
+ overflow:hidden;
18
+ }
19
+
20
+ .img-comp-img img {
21
+ display:block;
22
+ vertical-align:middle;
23
+ }
24
+
25
+ .img-comp-slider {
26
+ position: absolute;
27
+ z-index:9;
28
+ cursor: ew-resize;
29
+ /*set the appearance of the slider:*/
30
+ width: 40px;
31
+ height: 40px;
32
+ background-color: #2196F3;
33
+ opacity: 0.7;
34
+ border-radius: 50%;
35
+ }
36
+ </style>
37
+ <script>
38
+ function initComparisons() {
39
+ var x, i;
40
+ /*find all elements with an "overlay" class:*/
41
+ x = document.getElementsByClassName("img-comp-overlay");
42
+ for (i = 0; i < x.length; i++) {
43
+ /*once for each "overlay" element:
44
+ pass the "overlay" element as a parameter when executing the compareImages function:*/
45
+ compareImages(x[i]);
46
+ }
47
+ function compareImages(img) {
48
+ var slider, img, clicked = 0, w, h;
49
+ /*get the width and height of the img element*/
50
+ w = img.offsetWidth;
51
+ h = img.offsetHeight;
52
+ /*set the width of the img element to 50%:*/
53
+ img.style.width = (w / 2) + "px";
54
+ /*create slider:*/
55
+ slider = document.createElement("DIV");
56
+ slider.setAttribute("class", "img-comp-slider");
57
+ /*insert slider*/
58
+ img.parentElement.insertBefore(slider, img);
59
+ /*position the slider in the middle:*/
60
+ slider.style.top = (h / 2) - (slider.offsetHeight / 2) + "px";
61
+ slider.style.left = (w / 2) - (slider.offsetWidth / 2) + "px";
62
+ /*execute a function when the mouse button is pressed:*/
63
+ slider.addEventListener("mousedown", slideReady);
64
+ /*and another function when the mouse button is released:*/
65
+ window.addEventListener("mouseup", slideFinish);
66
+ /*or touched (for touch screens):*/
67
+ slider.addEventListener("touchstart", slideReady);
68
+ /*and released (for touch screens):*/
69
+ window.addEventListener("touchend", slideFinish);
70
+ function slideReady(e) {
71
+ /*prevent any other actions that may occur when moving over the image:*/
72
+ e.preventDefault();
73
+ /*the slider is now clicked and ready to move:*/
74
+ clicked = 1;
75
+ /*execute a function when the slider is moved:*/
76
+ window.addEventListener("mousemove", slideMove);
77
+ window.addEventListener("touchmove", slideMove);
78
+ }
79
+ function slideFinish() {
80
+ /*the slider is no longer clicked:*/
81
+ clicked = 0;
82
+ }
83
+ function slideMove(e) {
84
+ var pos;
85
+ /*if the slider is no longer clicked, exit this function:*/
86
+ if (clicked == 0) return false;
87
+ /*get the cursor's x position:*/
88
+ pos = getCursorPos(e)
89
+ /*prevent the slider from being positioned outside the image:*/
90
+ if (pos < 0) pos = 0;
91
+ if (pos > w) pos = w;
92
+ /*execute a function that will resize the overlay image according to the cursor:*/
93
+ slide(pos);
94
+ }
95
+ function getCursorPos(e) {
96
+ var a, x = 0;
97
+ e = (e.changedTouches) ? e.changedTouches[0] : e;
98
+ /*get the x positions of the image:*/
99
+ a = img.getBoundingClientRect();
100
+ /*calculate the cursor's x coordinate, relative to the image:*/
101
+ x = e.pageX - a.left;
102
+ /*consider any page scrolling:*/
103
+ x = x - window.pageXOffset;
104
+ return x;
105
+ }
106
+ function slide(x) {
107
+ /*resize the image:*/
108
+ img.style.width = x + "px";
109
+ /*position the slider:*/
110
+ slider.style.left = img.offsetWidth - (slider.offsetWidth / 2) + "px";
111
+ }
112
+ }
113
+ }
114
+ </script>
115
+ </head>
116
+ <body>
117
+
118
+ <h1>Compare Two Images</h1>
119
+
120
+ <p>Click and slide the blue slider to compare two images:</p>
121
+
122
+ <div class="img-comp-container">
123
+ <div class="img-comp-img">
124
+ <img src="img_snow.jpg" width="300" height="200">
125
+ </div>
126
+ <div class="img-comp-img img-comp-overlay">
127
+ <img src="img_forest.jpg" width="300" height="200">
128
+ </div>
129
+ </div>
130
+
131
+ <script>
132
+ /*Execute a function that will execute an image compare function for each element with the img-comp-overlay class:*/
133
+ initComparisons();
134
+ </script>
135
+
136
+ </body>
137
+ </html>
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .inference import Inference
2
+ from .inference import Matting
3
+ from .util import log
tools/inference.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ from .model import PHNet
3
+ import torchvision.transforms.functional as tf
4
+ from .util import inference_img, log
5
+ from .stylematte import StyleMatte
6
+ import numpy as np
7
+
8
+
9
+ class Inference:
10
+ def __init__(self, **kwargs):
11
+ self.rank = 0
12
+ self.__dict__.update(kwargs)
13
+ self.model = PHNet(enc_sizes=self.enc_sizes,
14
+ skips=self.skips,
15
+ grid_count=self.grid_counts,
16
+ init_weights=self.init_weights,
17
+ init_value=self.init_value)
18
+ log(f"checkpoint: {self.checkpoint.harmonizer}")
19
+ state = torch.load(self.checkpoint.harmonizer,
20
+ map_location=self.device)
21
+
22
+ self.model.load_state_dict(state, strict=True)
23
+ self.model.eval()
24
+
25
+ def harmonize(self, composite, mask):
26
+ if len(composite.shape) < 4:
27
+ composite = composite.unsqueeze(0)
28
+ while len(mask.shape) < 4:
29
+ mask = mask.unsqueeze(0)
30
+ composite = tf.resize(composite, [self.image_size, self.image_size])
31
+ mask = tf.resize(mask, [self.image_size, self.image_size])
32
+ log(f"{composite.shape}, {mask.shape}")
33
+ with torch.no_grad():
34
+ harmonized = self.model(composite, mask)['harmonized']
35
+
36
+ result = harmonized * mask + composite * (1-mask)
37
+ print(result.shape)
38
+ return result
39
+
40
+
41
+ class Matting:
42
+ def __init__(self, **kwargs):
43
+ self.rank = 0
44
+ self.__dict__.update(kwargs)
45
+ self.model = StyleMatte().to(self.device)
46
+ log(f"checkpoint: {self.checkpoint.matting}")
47
+ state = torch.load(self.checkpoint.matting, map_location=self.device)
48
+ self.model.load_state_dict(state, strict=True)
49
+ self.model.eval()
50
+
51
+ def extract(self, inp):
52
+ mask = inference_img(self.model, inp, self.device)
53
+ inp_np = np.array(inp)
54
+ fg = mask[:, :, None]*inp_np
55
+
56
+ return [mask, fg]
tools/model.py ADDED
@@ -0,0 +1,296 @@
1
+ from matplotlib import pyplot as plt
2
+ # from shtools import shReconstructSignal
3
+ from torchvision import transforms, utils
4
+ # from torchvision.ops import SqueezeExcitation
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ import torch
9
+ import math
10
+ import cv2
11
+ import numpy as np
12
+ from .normalizer import PatchNormalizer, PatchedHarmonizer
13
+ from .util import rgb_to_lab, lab_to_rgb, lab_shift
14
+
15
+ # from shtools import *
16
+ # from color_converters import luv_to_rgb, rgb_to_luv
17
+ # from skimage import io, transform
18
+ '''
19
+ Input (256,512,3)
20
+ '''
21
+
22
+
23
+ def inpaint_bg(comp, mask, dim=[2, 3]):
24
+ """
25
+ inpaint bg for ihd
26
+ Args:
27
+ comp (torch.float): [0:1]
28
+ mask (torch.float): [0:1]
29
+ """
30
+ back = comp * (1-mask) # *255
31
+ sum = torch.sum(back, dim=dim) # (B, C)
32
+ num = torch.sum((1-mask), dim=dim) # (B, C)
33
+ mu = sum / (num)
34
+ mean = mu[:, :, None, None]
35
+ back = back + mask * mean
36
+
37
+ return back
38
+
39
+
40
+ class ConvTransposeUp(nn.Sequential):
41
+ def __init__(self, in_channels, out_channels, kernel_size=4, padding=1, stride=2, activation=None):
42
+ super().__init__(
43
+ nn.ConvTranspose2d(in_channels, out_channels,
44
+ kernel_size=kernel_size, padding=padding, stride=stride),
45
+ activation() if activation is not None else nn.Identity(),
46
+ )
47
+
48
+
49
+ class UpsampleShuffle(nn.Sequential):
50
+ def __init__(self, in_channels, out_channels, activation=True):
51
+ super().__init__(
52
+ nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
53
+ nn.GELU() if activation else nn.Identity(),
54
+ nn.PixelShuffle(2)
55
+ )
56
+
57
+ def reset_parameters(self):
58
+ init_subpixel(self[0].weight)
59
+ nn.init.zeros_(self[0].bias)
60
+
61
+
62
+ class UpsampleResize(nn.Sequential):
63
+ def __init__(self, in_channels, out_channels, out_size=None, activation=None, scale_factor=2., mode='bilinear'):
64
+ super().__init__(
65
+ nn.Upsample(scale_factor=scale_factor, mode=mode) if out_size is None else nn.Upsample(
66
+ out_size, mode=mode),
67
+ nn.ReflectionPad2d(1),
68
+ nn.Conv2d(in_channels, out_channels,
69
+ kernel_size=3, stride=1, padding=0),
70
+ activation() if activation is not None else nn.Identity(),
71
+
72
+ )
73
+
74
+
75
+ def conv_bn(in_, out_, kernel_size=3, stride=1, padding=1, activation=nn.ReLU, normalization=nn.InstanceNorm2d):
76
+
77
+ return nn.Sequential(
78
+ nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=padding),
79
+ normalization(out_) if normalization is not None else nn.Identity(),
80
+ activation(),
81
+ )
82
+
83
+
84
+ def init_subpixel(weight):
85
+ co, ci, h, w = weight.shape
86
+ co2 = co // 4
87
+ # initialize sub kernel
88
+ k = torch.empty([co2, ci, h, w])
89
+ nn.init.kaiming_uniform_(k)
90
+ # repeat 4 times
91
+ k = k.repeat_interleave(4, dim=0)
92
+ weight.data.copy_(k)
93
+
94
+
95
+ class DownsampleShuffle(nn.Sequential):
96
+ def __init__(self, in_channels):
97
+ assert in_channels % 4 == 0
98
+ super().__init__(
99
+ nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
100
+ nn.ReLU(),
101
+ nn.PixelUnshuffle(2)
102
+ )
103
+
104
+ def reset_parameters(self):
105
+ init_subpixel(self[0].weight)
106
+ nn.init.zeros_(self[0].bias)
107
+
108
+
109
+ def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
110
+ # conv layer with ELU activation function
111
+ pad = int(kernel_size/2)
112
+ if padding is False:
113
+ pad = 0
114
+ return nn.Sequential(
115
+ nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
116
+ nn.ELU(),
117
+ )
118
+
119
+
120
+ class Inference_Data(Dataset):
121
+ def __init__(self, img_path):
122
+ self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
123
+ self.input_img = cv2.resize(
124
+ self.input_img, (512, 256), interpolation=cv2.INTER_CUBIC)
125
+ self.to_tensor = transforms.ToTensor()
126
+ self.data_len = 1
127
+
128
+ def __getitem__(self, index):
129
+ self.tensor_img = self.to_tensor(self.input_img)
130
+ return self.tensor_img
131
+
132
+ def __len__(self):
133
+ return self.data_len
134
+
135
+
136
+ class SEBlock(nn.Module):
137
+ def __init__(self, channel, reduction=8):
138
+ super(SEBlock, self).__init__()
139
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
140
+ self.fc = nn.Sequential(
141
+ nn.Linear(channel, channel//reduction),
142
+ nn.ReLU(inplace=True),
143
+ nn.Linear(channel//reduction, channel),
144
+ nn.Sigmoid())
145
+
146
+ def forward(self, x, aux_inp=None):
147
+ b, c, w, h = x.size()
148
+
149
+ def scale(x):
150
+ return (x - x.min()) / (x.max() - x.min() + 1e-8)
151
+ y1 = self.avg_pool(x).view(b, c)
152
+ y = self.fc(y1).view(b, c, 1, 1)
153
+ r = x*y
154
+ if aux_inp is not None:
155
+ aux_weights = nn.AdaptiveAvgPool2d(aux_inp.shape[-1]//8)(aux_inp)
156
+ aux_weights = nn.Sigmoid()(aux_weights.mean(1, keepdim=True))
157
+ tmp = x*aux_weights
158
+ tmp_img = (tmp - tmp.min()) / (tmp.max() - tmp.min())
159
+ r += tmp
160
+
161
+ return r
162
+
163
+
164
+ class ConvTransposeUp(nn.Sequential):
165
+ def __init__(self, in_channels, out_channels, norm, kernel_size=3, stride=2, padding=1, activation=None):
166
+ super().__init__(
167
+ nn.ConvTranspose2d(in_channels, out_channels,
168
+ # output_padding=output_padding, dilation=dilation
169
+ kernel_size=kernel_size, padding=padding, stride=stride,
170
+ ),
171
+ norm(out_channels) if norm is not None else nn.Identity(),
172
+ activation() if activation is not None else nn.Identity(),
173
+ )
174
+
175
+
176
+ class SkipConnect(nn.Module):
177
+ """docstring for RegionalSkipConnect"""
178
+
179
+ def __init__(self, channel):
180
+ super(SkipConnect, self).__init__()
181
+ self.rconv = nn.Conv2d(channel*2, channel, 3, padding=1, bias=False)
182
+
183
+ def forward(self, feature):
184
+ return F.relu(self.rconv(feature))
185
+
186
+
187
+ class AttentionBlock(nn.Module):
188
+ def __init__(self, in_channels):
189
+ super(AttentionBlock, self).__init__()
190
+ self.attn = nn.Sequential(
191
+ nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=1),
192
+ nn.Sigmoid()
193
+ )
194
+
195
+ def forward(self, x):
196
+ return self.attn(x)
197
+
198
+
199
+ class PatchHarmonizerBlock(nn.Module):
200
+ def __init__(self, in_channels=3, grid_count=5):
201
+ super(PatchHarmonizerBlock, self).__init__()
202
+ self.patch_harmonizer = PatchedHarmonizer(grid_count=grid_count)
203
+ self.head = conv_bn(in_channels*2, in_channels,
204
+ kernel_size=3, padding=1, normalization=None)
205
+
206
+ def forward(self, fg, bg, mask):
207
+ fg_harm, _ = self.patch_harmonizer(fg, bg, mask)
208
+
209
+ return self.head(torch.cat([fg, fg_harm], 1))
210
+
211
+
212
+ class PHNet(nn.Module):
213
+ def __init__(self, enc_sizes=[3, 16, 32, 64, 128, 256, 512], skips=True, grid_count=[10, 5, 1], init_weights=[0.5, 0.5], init_value=0.8):
214
+ super(PHNet, self).__init__()
215
+ self.skips = skips
216
+ self.feature_extractor = PatchHarmonizerBlock(
217
+ in_channels=enc_sizes[0], grid_count=grid_count[1])
218
+ self.encoder = nn.ModuleList([
219
+ conv_bn(enc_sizes[0], enc_sizes[1],
220
+ kernel_size=4, stride=2),
221
+ conv_bn(enc_sizes[1], enc_sizes[2],
222
+ kernel_size=3, stride=1),
223
+ conv_bn(enc_sizes[2], enc_sizes[3],
224
+ kernel_size=4, stride=2),
225
+ conv_bn(enc_sizes[3], enc_sizes[4],
226
+ kernel_size=3, stride=1),
227
+ conv_bn(enc_sizes[4], enc_sizes[5],
228
+ kernel_size=4, stride=2),
229
+ conv_bn(enc_sizes[5], enc_sizes[6],
230
+ kernel_size=3, stride=1),
231
+ ])
232
+
233
+ dec_ins = enc_sizes[::-1]
234
+ dec_sizes = enc_sizes[::-1]
235
+ self.start_level = len(dec_sizes) - len(grid_count)
236
+ self.normalizers = nn.ModuleList([
237
+ PatchNormalizer(in_channels=dec_sizes[self.start_level+i], grid_count=count, weights=init_weights, eps=1e-7, init_value=init_value) for i, count in enumerate(grid_count)
238
+ ])
239
+
240
+ self.decoder = nn.ModuleList([
241
+ ConvTransposeUp(
242
+ dec_ins[0], dec_sizes[1], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
243
+ ConvTransposeUp(
244
+ dec_ins[1], dec_sizes[2], norm=nn.BatchNorm2d, kernel_size=4, stride=2, activation=nn.LeakyReLU),
245
+ ConvTransposeUp(
246
+ dec_ins[2], dec_sizes[3], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
247
+ ConvTransposeUp(
248
+ dec_ins[3], dec_sizes[4], norm=None, kernel_size=4, stride=2, activation=nn.LeakyReLU),
249
+ ConvTransposeUp(
250
+ dec_ins[4], dec_sizes[5], norm=None, kernel_size=3, stride=1, activation=nn.LeakyReLU),
251
+ ConvTransposeUp(
252
+ dec_ins[5], 3, norm=None, kernel_size=4, stride=2, activation=None),
253
+ ])
254
+
255
+ self.skip = nn.ModuleList([
256
+ SkipConnect(x) for x in dec_ins
257
+ ])
258
+
259
+ self.SE_block = SEBlock(enc_sizes[6])
260
+
261
+ def forward(self, img, mask):
262
+ x = img
263
+
264
+ enc_outs = [x]
265
+ x_harm = self.feature_extractor(x*mask, x*(1-mask), mask)
266
+
267
+ # x = x_harm
268
+ masks = [mask]
269
+ for i, down_layer in enumerate(self.encoder):
270
+ x = down_layer(x)
271
+ scale_factor = 1. / (pow(2, 1 - i % 2))
272
+ masks.append(F.interpolate(masks[-1], scale_factor=scale_factor))
273
+ enc_outs.append(x)
274
+
275
+ x = self.SE_block(x, aux_inp=x_harm)
276
+
277
+ masks = masks[::-1]
278
+ for i, (up_layer, enc_out) in enumerate(zip(self.decoder, enc_outs[::-1])):
279
+ if i >= self.start_level:
280
+ enc_out = self.normalizers[i -
281
+ self.start_level](enc_out, enc_out, masks[i])
282
+ x = torch.cat([x, enc_out], 1)
283
+ x = self.skip[i](x)
284
+ x = up_layer(x)
285
+
286
+ relighted = F.sigmoid(x)
287
+
288
+ return {
289
+ "harmonized": relighted, # target prediction
290
+ }
291
+
292
+ def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"], value=False):
293
+ for module in modules:
294
+ attr = getattr(self, module, None)
295
+ if attr is not None:
296
+ attr.requires_grad_(value)
tools/normalizer.py ADDED
@@ -0,0 +1,261 @@
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+ import tqdm
5
+ import time
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from .util import rgb_to_lab, lab_to_rgb
10
+
11
+
12
+ def blend(f, b, a):
13
+ return f*a + b*(1 - a)
14
+
15
+
16
+ class PatchedHarmonizer(nn.Module):
17
+ def __init__(self, grid_count=1, init_weights=[0.9, 0.1]):
18
+ super(PatchedHarmonizer, self).__init__()
19
+ self.eps = 1e-8
20
+ # self.weights = torch.nn.Parameter(torch.ones((grid_count, grid_count)), requires_grad=True)
21
+ # self.grid_weights_ = torch.nn.Parameter(torch.FloatTensor(init_weights), requires_grad=True)
22
+ self.grid_weights = torch.nn.Parameter(
23
+ torch.FloatTensor(init_weights), requires_grad=True)
24
+ # self.weights.retain_graph = True
25
+ self.grid_count = grid_count
26
+
27
+ def lab_shift(self, x, invert=False):
28
+ x = x.float()
29
+ if invert:
30
+ x[:, 0, :, :] /= 2.55
31
+ x[:, 1, :, :] -= 128
32
+ x[:, 2, :, :] -= 128
33
+ else:
34
+ x[:, 0, :, :] *= 2.55
35
+ x[:, 1, :, :] += 128
36
+ x[:, 2, :, :] += 128
37
+
38
+ return x
39
+
40
+ def get_mean_std(self, img, mask, dim=[2, 3]):
41
+ sum = torch.sum(img*mask, dim=dim) # (B, C)
42
+ num = torch.sum(mask, dim=dim) # (B, C)
43
+ mu = sum / (num + self.eps)
44
+ mean = mu[:, :, None, None]
45
+ var = torch.sum(((img - mean)*mask) ** 2, dim=dim) / (num + self.eps)
46
+ var = var[:, :, None, None]
47
+
48
+ return mean, torch.sqrt(var+self.eps)
49
+
50
+ def compute_patch_statistics(self, lab):
51
+ means, stds = [], []
52
+ bs, dx, dy = lab.shape[0], lab.shape[2] // self.grid_count, lab.shape[3] // self.grid_count
53
+ for h in range(self.grid_count):
54
+ cmeans, cstds = [], []
55
+ for w in range(self.grid_count):
56
+ ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
57
+ if h == self.grid_count - 1:
58
+ ind[1] = None
59
+ if w == self.grid_count - 1:
60
+ ind[-1] = None
61
+ m, v = self.compute_mean_var(
62
+ lab[:, :, ind[0]:ind[1], ind[2]:ind[3]], dim=[2, 3])
63
+ cmeans.append(m)
64
+ cstds.append(v)
65
+ means.append(cmeans)
66
+ stds.append(cstds)
67
+
68
+ return means, stds
69
+
70
+ def compute_mean_var(self, x, dim=[1, 2]):
71
+ mean = x.float().mean(dim=dim)[:, :, None, None]
72
+ var = torch.sqrt(x.float().var(dim=dim))[:, :, None, None]
73
+
74
+ return mean, var
75
+
76
+ def forward(self, fg_rgb, bg_rgb, alpha, masked_stats=False):
77
+
78
+ bg_rgb = F.interpolate(bg_rgb, size=(
79
+ fg_rgb.shape[2:])) # b x C x H x W
80
+
81
+ bg_lab = bg_rgb # self.lab_shift(rgb_to_lab(bg_rgb/255.))
82
+ fg_lab = fg_rgb # self.lab_shift(rgb_to_lab(fg_rgb/255.))
83
+
84
+ if masked_stats:
85
+ self.bg_global_mean, self.bg_global_var = self.get_mean_std(
86
+ img=bg_lab, mask=(1-alpha))
87
+ self.fg_global_mean, self.fg_global_var = self.get_mean_std(
88
+ img=fg_lab, mask=torch.ones_like(alpha))
89
+ else:
90
+ self.bg_global_mean, self.bg_global_var = self.compute_mean_var(bg_lab, dim=[
91
+ 2, 3])
92
+ self.fg_global_mean, self.fg_global_var = self.compute_mean_var(fg_lab, dim=[
93
+ 2, 3])
94
+
95
+ self.bg_means, self.bg_vars = self.compute_patch_statistics(
96
+ bg_lab)
97
+ self.fg_means, self.fg_vars = self.compute_patch_statistics(
98
+ fg_lab)
99
+
100
+ fg_harm = self.harmonize(fg_lab)
101
+ # fg_harm = lab_to_rgb(fg_harm)
102
+ bg = F.interpolate(bg_rgb, size=(fg_rgb.shape[2:]))/255.
103
+
104
+ composite = blend(fg_harm, bg, alpha)
105
+
106
+ return composite, fg_harm
107
+
108
+ def harmonize(self, fg):
109
+ harmonized = torch.zeros_like(fg)
110
+ dx = fg.shape[2] // self.grid_count
111
+ dy = fg.shape[3] // self.grid_count
112
+ for h in range(self.grid_count):
113
+ for w in range(self.grid_count):
114
+ ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
115
+ if h == self.grid_count - 1:
116
+ ind[1] = None
117
+ if w == self.grid_count - 1:
118
+ ind[-1] = None
119
+ harmonized[:, :, ind[0]:ind[1], ind[2]:ind[3]] = self.normalize_channel(
120
+ fg[:, :, ind[0]:ind[1], ind[2]:ind[3]], h, w)
121
+
122
+ # harmonized = self.lab_shift(harmonized, invert=True)
123
+
124
+ return harmonized
125
+
126
+ def normalize_channel(self, value, h, w):
127
+
128
+ fg_local_mean, fg_local_var = self.fg_means[h][w], self.fg_vars[h][w]
129
+ bg_local_mean, bg_local_var = self.bg_means[h][w], self.bg_vars[h][w]
130
+ fg_global_mean, fg_global_var = self.fg_global_mean, self.fg_global_var
131
+ bg_global_mean, bg_global_var = self.bg_global_mean, self.bg_global_var
132
+
133
+ # global2global normalization
134
+ zeroed_mean = value - fg_global_mean
135
+ # (fg_v * div_global_v + (1-fg_v) * div_v)
136
+ scaled_var = zeroed_mean * (bg_global_var/(fg_global_var + self.eps))
137
+ normalized_global = scaled_var + bg_global_mean
138
+
139
+ # local2local normalization
140
+ zeroed_mean = value - fg_local_mean
141
+ # (fg_v * div_global_v + (1-fg_v) * div_v)
142
+ scaled_var = zeroed_mean * (bg_local_var/(fg_local_var + self.eps))
143
+ normalized_local = scaled_var + bg_local_mean
144
+
145
+ return self.grid_weights[0]*normalized_local + self.grid_weights[1]*normalized_global
146
+
147
+ def normalize_fg(self, value):
148
+ zeroed_mean = value - \
149
+ (self.fg_local_mean *
150
+ self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
151
+ # (fg_v * div_global_v + (1-fg_v) * div_v)
152
+ scaled_var = zeroed_mean * \
153
+ (self.bg_global_var/(self.fg_global_var + self.eps))
154
+ normalized_lg = scaled_var + \
155
+ (self.bg_local_mean *
156
+ self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
157
+
158
+ return normalized_lg
159
+
160
+
161
+ class PatchNormalizer(nn.Module):
162
+ def __init__(self, in_channels=3, eps=1e-7, grid_count=1, weights=[0.5, 0.5], init_value=1e-2):
163
+ super(PatchNormalizer, self).__init__()
164
+ self.grid_count = grid_count
165
+ self.eps = eps
166
+
167
+ self.weights = nn.Parameter(
168
+ torch.FloatTensor(weights), requires_grad=True)
169
+ self.fg_var = nn.Parameter(
170
+ init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
171
+ self.fg_bias = nn.Parameter(
172
+ init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
173
+ self.patched_fg_var = nn.Parameter(
174
+ init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
175
+ self.patched_fg_bias = nn.Parameter(
176
+ init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
177
+ self.bg_var = nn.Parameter(
178
+ init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
179
+ self.bg_bias = nn.Parameter(
180
+ init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
181
+ self.grid_weights = torch.nn.Parameter(torch.ones((in_channels, grid_count, grid_count))[
182
+ None, :, :, :] / (grid_count*grid_count*in_channels), requires_grad=True)
183
+
184
+ def local_normalization(self, value):
185
+ zeroed_mean = value - \
186
+ (self.fg_local_mean *
187
+ self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
188
+ # (fg_v * div_global_v + (1-fg_v) * div_v)
189
+ scaled_var = zeroed_mean * \
190
+ (self.bg_global_var/(self.fg_global_var + self.eps))
191
+ normalized_lg = scaled_var + \
192
+ (self.bg_local_mean *
193
+ self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
194
+
195
+ return normalized_lg
196
+
197
+ def get_mean_std(self, img, mask, dim=[2, 3]):
198
+ sum = torch.sum(img*mask, dim=dim) # (B, C)
199
+ num = torch.sum(mask, dim=dim) # (B, C)
200
+ mu = sum / (num + self.eps)
201
+ mean = mu[:, :, None, None]
202
+ var = torch.sum(((img - mean)*mask) ** 2, dim=dim) / (num + self.eps)
203
+ var = var[:, :, None, None]
204
+
205
+ return mean, torch.sqrt(var+self.eps)
206
+
207
+ def compute_patch_statistics(self, img, mask):
208
+ means, stds = [], []
209
+ bs, dx, dy = img.shape[0], img.shape[2] // self.grid_count, img.shape[3] // self.grid_count
210
+ for h in range(self.grid_count):
211
+ cmeans, cstds = [], []
212
+ for w in range(self.grid_count):
213
+ ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
214
+ if h == self.grid_count - 1:
215
+ ind[1] = None
216
+ if w == self.grid_count - 1:
217
+ ind[-1] = None
218
+ m, v = self.get_mean_std(
219
+ img[:, :, ind[0]:ind[1], ind[2]:ind[3]], mask[:, :, ind[0]:ind[1], ind[2]:ind[3]], dim=[2, 3])
220
+ cmeans.append(m.reshape(m.shape[:2]))
221
+ cstds.append(v.reshape(v.shape[:2]))
222
+ means.append(torch.stack(cmeans))
223
+ stds.append(torch.stack(cstds))
224
+
225
+ return torch.stack(means), torch.stack(stds)
226
+
227
+ def compute_mean_var(self, x, dim=[2, 3]):
228
+ mean = x.float().mean(dim=dim)
229
+ var = torch.sqrt(x.float().var(dim=dim))
230
+
231
+ return mean, var
232
+
233
+ def forward(self, fg, bg, mask):
234
+
235
+ self.local_means, self.local_vars = self.compute_patch_statistics(
236
+ bg, (1-mask))
237
+
238
+ bg_mean, bg_var = self.get_mean_std(bg, 1 - mask)
239
+ zeroed_mean = (bg - bg_mean)
240
+ unscaled = zeroed_mean / bg_var
241
+ bg_normalized = unscaled * self.bg_var + self.bg_bias
242
+
243
+ fg_mean, fg_var = self.get_mean_std(fg, mask)
244
+ zeroed_mean = fg - fg_mean
245
+ unscaled = zeroed_mean / fg_var
246
+
247
+ mean_patched_back = (self.local_means.permute(
248
+ 2, 3, 0, 1)*self.grid_weights).sum(dim=[2, 3])[:, :, None, None]
249
+
250
+ normalized = unscaled * bg_var + bg_mean
251
+ patch_normalized = unscaled * bg_var + mean_patched_back
252
+
253
+ fg_normalized = normalized * self.fg_var + self.fg_bias
254
+ fg_patch_normalized = patch_normalized * \
255
+ self.patched_fg_var + self.patched_fg_bias
256
+
257
+ fg_result = self.weights[0] * fg_normalized + \
258
+ self.weights[1] * fg_patch_normalized
259
+ composite = blend(fg_result, bg_normalized, mask)
260
+
261
+ return composite
tools/stylematte.py ADDED
@@ -0,0 +1,506 @@
1
+ import cv2
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from typing import List
8
+ from itertools import chain
9
+
10
+ from transformers import SegformerForSemanticSegmentation, Mask2FormerForUniversalSegmentation
11
+ device = 'cpu'
12
+
13
+
14
+ class EncoderDecoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ encoder,
18
+ decoder,
19
+ prefix=nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True),
20
+ ):
21
+ super().__init__()
22
+ self.encoder = encoder
23
+ self.decoder = decoder
24
+ self.prefix = prefix
25
+
26
+ def forward(self, x):
27
+ if self.prefix is not None:
28
+ x = self.prefix(x)
29
+ x = self.encoder(x)["hidden_states"] # transformers
30
+ return self.decoder(x)
31
+
32
+
33
+ def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
34
+ return nn.Sequential(
35
+ nn.Conv2d(input_filters, output_filters,
36
+ kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
37
+ nn.LeakyReLU(0.2, inplace=True),
38
+ nn.BatchNorm2d(output_filters)
39
+ )
40
+
41
+
42
+ def up_and_add(x, y):
43
+ return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
44
+
45
+
46
+ class FPN_fuse(nn.Module):
47
+ def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
48
+ super(FPN_fuse, self).__init__()
49
+ assert feature_channels[0] == fpn_out
50
+ self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
51
+ for ft_size in feature_channels[1:]])
52
+ self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
53
+ * (len(feature_channels)-1))
54
+ self.conv_fusion = nn.Sequential(
55
+ nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3,
56
+ padding=1, bias=False),
57
+ nn.BatchNorm2d(fpn_out),
58
+ nn.ReLU(inplace=True),
59
+ )
60
+
61
+ def forward(self, features):
62
+
63
+ features[:-1] = [conv1x1(feature) for feature,
64
+ conv1x1 in zip(features[:-1], self.conv1x1)]
65
+ feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
66
+ feature = up_and_add(self.smooth_conv[1](feature), features[2])
67
+ feature = up_and_add(self.smooth_conv[2](feature), features[3])
68
+
69
+ H, W = features[-1].size(2), features[-1].size(3)
70
+ x = [feature, features[-1]]
71
+ x = [F.interpolate(x_el, size=(H, W), mode='bilinear',
72
+ align_corners=True) for x_el in x]
73
+
74
+ x = self.conv_fusion(torch.cat(x, dim=1))
75
+ # x = F.interpolate(x, size=(H*4, W*4), mode='bilinear', align_corners=True)
76
+ return x
77
+
78
+
79
+ class PSPModule(nn.Module):
80
+ # In the original inmplementation they use precise RoI pooling
81
+ # Instead of using adaptative average pooling
82
+ def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
83
+ super(PSPModule, self).__init__()
84
+ out_channels = in_channels // len(bin_sizes)
85
+ self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
86
+ for b_s in bin_sizes])
87
+ self.bottleneck = nn.Sequential(
88
+ nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
89
+ kernel_size=3, padding=1, bias=False),
90
+ nn.BatchNorm2d(in_channels),
91
+ nn.ReLU(inplace=True),
92
+ nn.Dropout2d(0.1)
93
+ )
94
+
95
+ def _make_stages(self, in_channels, out_channels, bin_sz):
96
+ prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
97
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
98
+ bn = nn.BatchNorm2d(out_channels)
99
+ relu = nn.ReLU(inplace=True)
100
+ return nn.Sequential(prior, conv, bn, relu)
101
+
102
+ def forward(self, features):
103
+ h, w = features.size()[2], features.size()[3]
104
+ pyramids = [features]
105
+ pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
106
+ align_corners=True) for stage in self.stages])
107
+ output = self.bottleneck(torch.cat(pyramids, dim=1))
108
+ return output
109
+
110
+
111
+ class UperNet_swin(nn.Module):
112
+ # Implementing only the object path
113
+ def __init__(self, backbone, pretrained=True):
114
+ super(UperNet_swin, self).__init__()
115
+
116
+ self.backbone = backbone
117
+ feature_channels = [192, 384, 768, 768]
118
+ self.PPN = PSPModule(feature_channels[-1])
119
+ self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
120
+ self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)
121
+
122
+ def forward(self, x):
123
+ input_size = (x.size()[2], x.size()[3])
124
+ features = self.backbone(x)["hidden_states"]
125
+ features[-1] = self.PPN(features[-1])
126
+ x = self.head(self.FPN(features))
127
+
128
+ x = F.interpolate(x, size=input_size, mode='bilinear')
129
+ return x
130
+
131
+ def get_backbone_params(self):
132
+ return self.backbone.parameters()
133
+
134
+ def get_decoder_params(self):
135
+ return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())
136
+
137
+
138
+ class UnetDecoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ encoder_channels=(3, 192, 384, 768, 768),
142
+ decoder_channels=(512, 256, 128, 64),
143
+ n_blocks=4,
144
+ use_batchnorm=True,
145
+ attention_type=None,
146
+ center=False,
147
+ ):
148
+ super().__init__()
149
+
150
+ if n_blocks != len(decoder_channels):
151
+ raise ValueError(
152
+ "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
153
+ n_blocks, len(decoder_channels)
154
+ )
155
+ )
156
+
157
+ # remove first skip with same spatial resolution
158
+ encoder_channels = encoder_channels[1:]
159
+ # reverse channels to start from head of encoder
160
+ encoder_channels = encoder_channels[::-1]
161
+
162
+ # computing blocks input and output channels
163
+ head_channels = encoder_channels[0]
164
+ in_channels = [head_channels] + list(decoder_channels[:-1])
165
+ skip_channels = list(encoder_channels[1:]) + [0]
166
+
167
+ out_channels = decoder_channels
168
+
169
+ if center:
170
+ self.center = CenterBlock(
171
+ head_channels, head_channels, use_batchnorm=use_batchnorm)
172
+ else:
173
+ self.center = nn.Identity()
174
+
175
+ # combine decoder keyword arguments
176
+ kwargs = dict(use_batchnorm=use_batchnorm,
177
+ attention_type=attention_type)
178
+ blocks = [
179
+ DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
180
+ for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
181
+ ]
182
+ self.blocks = nn.ModuleList(blocks)
183
+ upscale_factor = 4
184
+ self.matting_head = nn.Sequential(
185
+ nn.Conv2d(64, 1, kernel_size=3, padding=1),
186
+ nn.ReLU(),
187
+ nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
188
+ )
189
+
190
+ def preprocess_features(self, x):
191
+ features = []
192
+ for out_tensor in x:
193
+ bs, n, f = out_tensor.size()
194
+ h = int(n**0.5)
195
+ feature = out_tensor.view(-1, h, h,
196
+ f).permute(0, 3, 1, 2).contiguous()
197
+ features.append(feature)
198
+ return features
199
+
200
+ def forward(self, features):
201
+ # remove first skip with same spatial resolution
202
+ features = features[1:]
203
+ # reverse channels to start from head of encoder
204
+ features = features[::-1]
205
+
206
+ features = self.preprocess_features(features)
207
+
208
+ head = features[0]
209
+ skips = features[1:]
210
+
211
+ x = self.center(head)
212
+ for i, decoder_block in enumerate(self.blocks):
213
+ skip = skips[i] if i < len(skips) else None
214
+ x = decoder_block(x, skip)
215
+ # y_i = self.upsample1(y_i)
216
+ # hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)
217
+ x = self.matting_head(x)
218
+ x = 1-nn.ReLU()(1-x)
219
+ return x
220
+
221
+
222
+ class SegmentationHead(nn.Sequential):
223
+ def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
224
+ conv2d = nn.Conv2d(in_channels, out_channels,
225
+ kernel_size=kernel_size, padding=kernel_size // 2)
226
+ upsampling = nn.UpsamplingBilinear2d(
227
+ scale_factor=upsampling) if upsampling > 1 else nn.Identity()
228
+ super().__init__(conv2d, upsampling)
229
+
230
+
231
+ class DecoderBlock(nn.Module):
232
+ def __init__(
233
+ self,
234
+ in_channels,
235
+ skip_channels,
236
+ out_channels,
237
+ use_batchnorm=True,
238
+ attention_type=None,
239
+ ):
240
+ super().__init__()
241
+ self.conv1 = conv2d_relu(
242
+ in_channels + skip_channels,
243
+ out_channels,
244
+ kernel_size=3
245
+ )
246
+ self.conv2 = conv2d_relu(
247
+ out_channels,
248
+ out_channels,
249
+ kernel_size=3,
250
+ )
251
+ self.in_channels = in_channels
252
+ self.out_channels = out_channels
253
+ self.skip_channels = skip_channels
254
+
255
+ def forward(self, x, skip=None):
256
+ if skip is None:
257
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
258
+ else:
259
+ if x.shape[-1] != skip.shape[-1]:
260
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
261
+ if skip is not None:
262
+ # print(x.shape,skip.shape)
263
+ x = torch.cat([x, skip], dim=1)
264
+ x = self.conv1(x)
265
+ x = self.conv2(x)
266
+ return x
267
+
268
+
269
+ class CenterBlock(nn.Sequential):
270
+ def __init__(self, in_channels, out_channels):
271
+ conv1 = conv2d_relu(
272
+ in_channels,
273
+ out_channels,
274
+ kernel_size=3,
275
+ )
276
+ conv2 = conv2d_relu(
277
+ out_channels,
278
+ out_channels,
279
+ kernel_size=3,
280
+ )
281
+ super().__init__(conv1, conv2)
282
+
283
+
284
+ class SegForm(nn.Module):
285
+ def __init__(self):
286
+ super(SegForm, self).__init__()
287
+ # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
288
+ # configuration.num_labels = 1 ## set output as 1
289
+ # self.model = SegformerForSemanticSegmentation(config=configuration)
290
+
291
+ self.model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
292
+ )
293
+
294
+ def forward(self, image):
295
+ img_segs = self.model(image)
296
+ upsampled_logits = nn.functional.interpolate(img_segs.logits,
297
+ scale_factor=4,
298
+ mode='nearest',
299
+ )
300
+ return upsampled_logits
301
+
302
+
303
+ class StyleMatte(nn.Module):
304
+ def __init__(self):
305
+ super(StyleMatte, self).__init__()
306
+ # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
307
+ # configuration.num_labels = 1 ## set output as 1
308
+ self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
309
+ self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained(
310
+ "facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
311
+ self.fgf = FastGuidedFilter()
312
+ self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
313
+ # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
314
+ # self.register_buffer('image_net_mean', self.mean)
315
+ # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
316
+ # self.register_buffer('image_net_std', self.std)
317
+
318
+ def forward(self, image, normalize=False):
319
+ # if normalize:
320
+ # image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))
321
+
322
+ decoder_out = self.pixel_decoder(image)
323
+ decoder_states = list(decoder_out.decoder_hidden_states)
324
+ decoder_states.append(decoder_out.decoder_last_hidden_state)
325
+ out_pure = self.fpn(decoder_states)
326
+
327
+ image_lr = nn.functional.interpolate(image.mean(1, keepdim=True),
328
+ scale_factor=0.25,
329
+ mode='bicubic',
330
+ align_corners=True
331
+ )
332
+ out = self.conv(out_pure)
333
+ out = self.fgf(image_lr, out, image.mean(
334
+ 1, keepdim=True)) # .clip(0,1)
335
+ # out = nn.Sigmoid()(out)
336
+ # out = nn.functional.interpolate(out,
337
+ # scale_factor=4,
338
+ # mode='bicubic',
339
+ # align_corners=True
340
+ # )
341
+
342
+ return torch.sigmoid(out)
343
+
344
+ def get_training_params(self):
345
+ # +list(self.fgf.parameters())
346
+ return list(self.fpn.parameters())+list(self.conv.parameters())
347
+
348
+
349
+ class GuidedFilter(nn.Module):
350
+ def __init__(self, r, eps=1e-8):
351
+ super(GuidedFilter, self).__init__()
352
+
353
+ self.r = r
354
+ self.eps = eps
355
+ self.boxfilter = BoxFilter(r)
356
+
357
+ def forward(self, x, y):
358
+ n_x, c_x, h_x, w_x = x.size()
359
+ n_y, c_y, h_y, w_y = y.size()
360
+
361
+ assert n_x == n_y
362
+ assert c_x == 1 or c_x == c_y
363
+ assert h_x == h_y and w_x == w_y
364
+ assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
365
+
366
+ # N
367
+ N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
368
+
369
+ # mean_x
370
+ mean_x = self.boxfilter(x) / N
371
+ # mean_y
372
+ mean_y = self.boxfilter(y) / N
373
+ # cov_xy
374
+ cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
375
+ # var_x
376
+ var_x = self.boxfilter(x * x) / N - mean_x * mean_x
377
+
378
+ # A
379
+ A = cov_xy / (var_x + self.eps)
380
+ # b
381
+ b = mean_y - A * mean_x
382
+
383
+ # mean_A; mean_b
384
+ mean_A = self.boxfilter(A) / N
385
+ mean_b = self.boxfilter(b) / N
386
+
387
+ return mean_A * x + mean_b
388
+
389
+
390
+ class FastGuidedFilter(nn.Module):
391
+ def __init__(self, r=1, eps=1e-8):
392
+ super(FastGuidedFilter, self).__init__()
393
+
394
+ self.r = r
395
+ self.eps = eps
396
+ self.boxfilter = BoxFilter(r)
397
+
398
+ def forward(self, lr_x, lr_y, hr_x):
399
+ n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
400
+ n_lry, c_lry, h_lry, w_lry = lr_y.size()
401
+ n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
402
+
403
+ assert n_lrx == n_lry and n_lry == n_hrx
404
+ assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
405
+ assert h_lrx == h_lry and w_lrx == w_lry
406
+ assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
407
+
408
+ # N
409
+ N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))
410
+
411
+ # mean_x
412
+ mean_x = self.boxfilter(lr_x) / N
413
+ # mean_y
414
+ mean_y = self.boxfilter(lr_y) / N
415
+ # cov_xy
416
+ cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
417
+ # var_x
418
+ var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
419
+
420
+ # A
421
+ A = cov_xy / (var_x + self.eps)
422
+ # b
423
+ b = mean_y - A * mean_x
424
+
425
+ # mean_A; mean_b
426
+ mean_A = F.interpolate(
427
+ A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
428
+ mean_b = F.interpolate(
429
+ b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
430
+
431
+ return mean_A*hr_x+mean_b
432
+
433
+
434
+ class DeepGuidedFilterRefiner(nn.Module):
435
+ def __init__(self, hid_channels=16):
436
+ super().__init__()
437
+ self.box_filter = nn.Conv2d(
438
+ 4, 4, kernel_size=3, padding=1, bias=False, groups=4)
439
+ self.box_filter.weight.data[...] = 1 / 9
440
+ self.conv = nn.Sequential(
441
+ nn.Conv2d(4 * 2 + hid_channels, hid_channels,
442
+ kernel_size=1, bias=False),
443
+ nn.BatchNorm2d(hid_channels),
444
+ nn.ReLU(True),
445
+ nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
446
+ nn.BatchNorm2d(hid_channels),
447
+ nn.ReLU(True),
448
+ nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
449
+ )
450
+
451
+ def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
452
+ fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
453
+ base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
454
+ base_y = torch.cat([base_fgr, base_pha], dim=1)
455
+
456
+ mean_x = self.box_filter(base_x)
457
+ mean_y = self.box_filter(base_y)
458
+ cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
459
+ var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
460
+
461
+ A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
462
+ b = mean_y - A * mean_x
463
+
464
+ H, W = fine_src.shape[2:]
465
+ A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
466
+ b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
467
+
468
+ out = A * fine_x + b
469
+ fgr, pha = out.split([3, 1], dim=1)
470
+ return fgr, pha
471
+
472
+
473
+ def diff_x(input, r):
474
+ assert input.dim() == 4
475
+
476
+ left = input[:, :, r:2 * r + 1]
477
+ middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
478
+ right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]
479
+
480
+ output = torch.cat([left, middle, right], dim=2)
481
+
482
+ return output
483
+
484
+
485
+ def diff_y(input, r):
486
+ assert input.dim() == 4
487
+
488
+ left = input[:, :, :, r:2 * r + 1]
489
+ middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
490
+ right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]
491
+
492
+ output = torch.cat([left, middle, right], dim=3)
493
+
494
+ return output
495
+
496
+
497
+ class BoxFilter(nn.Module):
498
+ def __init__(self, r):
499
+ super(BoxFilter, self).__init__()
500
+
501
+ self.r = r
502
+
503
+ def forward(self, x):
504
+ assert x.dim() == 4
505
+
506
+ return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
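As a reading aid for the filtering code above (a restatement, not an addition): GuidedFilter and FastGuidedFilter implement the closed-form guided filter

$$A = \frac{\operatorname{cov}(x, y)}{\operatorname{var}(x) + \epsilon}, \qquad b = \bar{y} - A\,\bar{x}, \qquad q = \bar{A}\,x + \bar{b},$$

with all means taken over box windows of radius r; BoxFilter computes those windowed sums in constant time per pixel via cumulative sums along each axis (diff_x / diff_y). FastGuidedFilter fits A and b on the low-resolution pair (lr_x, lr_y) and bilinearly upsamples them before applying them to hr_x.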
tools/util.py ADDED
@@ -0,0 +1,345 @@
1
+ import math
2
+ import numpy as np
3
+ from typing import Tuple
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision.utils import make_grid
7
+ import cv2
8
+ from torchvision import transforms, models
9
+
10
+
11
+ def log(msg, lvl='info'):
12
+ if lvl == 'info':
13
+ print(f"***********{msg}****************")
14
+ if lvl == 'error':
15
+ print(f"!!! Exception: {msg} !!!")
16
+
17
+
18
+ def lab_shift(x, invert=False):
19
+ x = x.float()
20
+ if invert:
21
+ x[:, 0, :, :] /= 2.55
22
+ x[:, 1, :, :] -= 128
23
+ x[:, 2, :, :] -= 128
24
+ else:
25
+ x[:, 0, :, :] *= 2.55
26
+ x[:, 1, :, :] += 128
27
+ x[:, 2, :, :] += 128
28
+
29
+ return x
30
+
31
+
32
+ def calculate_psnr(img1, img2):
33
+ # img1 and img2 have range [0, 255]
34
+ img1 = img1.astype(np.float64)
35
+ img2 = img2.astype(np.float64)
36
+ mse = np.mean((img1 - img2)**2)
37
+ if mse == 0:
38
+ return float('inf')
39
+
40
+ return 20 * math.log10(255.0 / math.sqrt(mse))
41
+
42
+
43
+ def calculate_fpsnr(fmse):
44
+ return 10 * math.log10(255.0 / (fmse + 1e-8))
45
+
46
+
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1), bit=8):
+     '''
+     Converts a torch Tensor into an image Numpy array
+     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+     '''
+     norm = float(2**bit) - 1
+     # print('before', tensor[:,:,0].max(), tensor[:,:,0].min(), '\t', tensor[:,:,1].max(), tensor[:,:,1].min(), '\t', tensor[:,:,2].max(), tensor[:,:,2].min())
+     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
+     # print('clamp ', tensor[:,:,0].max(), tensor[:,:,0].min(), '\t', tensor[:,:,1].max(), tensor[:,:,1].min(), '\t', tensor[:,:,2].max(), tensor[:,:,2].min())
+     tensor = (tensor - min_max[0]) / \
+         (min_max[1] - min_max[0])  # to range [0,1]
+     n_dim = tensor.dim()
+     if n_dim == 4:
+         n_img = len(tensor)
+         img_np = make_grid(tensor, nrow=int(
+             math.sqrt(n_img)), normalize=False).numpy()
+         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+     elif n_dim == 3:
+         img_np = tensor.numpy()
+         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+     elif n_dim == 2:
+         img_np = tensor.numpy()
+     else:
+         raise TypeError(
+             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+     if out_type == np.uint8:
+         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+         img_np = (img_np * norm).round()
+     return img_np.astype(out_type)
+
+
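Illustrative usage of tensor2img (not from the repo): a CHW float tensor in [0, 1] with RGB channel order comes back as an HWC uint8 array in BGR order, i.e. ready for cv2.imwrite:

import torch
from tools.util import tensor2img

chw = torch.rand(3, 32, 32)           # RGB, values in [0, 1]
img = tensor2img(chw)
print(img.shape, img.dtype)           # expected: (32, 32, 3) uint8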
+ def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
+     r"""Convert an RGB image to Lab.
+
+     .. image:: _static/img/rgb_to_lab.png
+
+     The input RGB image is assumed to be in the range of :math:`[0, 1]`. Lab
+     color is computed using the D65 illuminant and Observer 2.
+
+     Args:
+         image: RGB Image to be converted to Lab with shape :math:`(*, 3, H, W)`.
+
+     Returns:
+         Lab version of the image with shape :math:`(*, 3, H, W)`.
+         The L channel values are in the range 0..100. a and b are in the range -128..127.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = rgb_to_lab(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     # Convert from sRGB to Linear RGB
+     lin_rgb = rgb_to_linear_rgb(image)
+
+     xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb)
+
+     # normalize for D65 white point
+     xyz_ref_white = torch.tensor(
+         [0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype)[..., :, None, None]
+     xyz_normalized = torch.div(xyz_im, xyz_ref_white)
+
+     threshold = 0.008856
+     power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0)
+     scale = 7.787 * xyz_normalized + 4.0 / 29.0
+     xyz_int = torch.where(xyz_normalized > threshold, power, scale)
+
+     x: torch.Tensor = xyz_int[..., 0, :, :]
+     y: torch.Tensor = xyz_int[..., 1, :, :]
+     z: torch.Tensor = xyz_int[..., 2, :, :]
+
+     L: torch.Tensor = (116.0 * y) - 16.0
+     a: torch.Tensor = 500.0 * (x - y)
+     _b: torch.Tensor = 200.0 * (y - z)
+
+     out: torch.Tensor = torch.stack([L, a, _b], dim=-3)
+
+     return out
+
+
+ def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
+     r"""Convert a Lab image to RGB.
+
+     The L channel is assumed to be in the range of :math:`[0, 100]`.
+     a and b channels are in the range of :math:`[-128, 127]`.
+
+     Args:
+         image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
+         clip: Whether to apply clipping to ensure output RGB values are in range :math:`[0, 1]`.
+
+     Returns:
+         RGB version of the image with shape :math:`(*, 3, H, W)`.
+         The output RGB image is in the range of :math:`[0, 1]`.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = lab_to_rgb(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     L: torch.Tensor = image[..., 0, :, :]
+     a: torch.Tensor = image[..., 1, :, :]
+     _b: torch.Tensor = image[..., 2, :, :]
+
+     fy = (L + 16.0) / 116.0
+     fx = (a / 500.0) + fy
+     fz = fy - (_b / 200.0)
+
+     # if color data out of range: Z < 0
+     fz = fz.clamp(min=0.0)
+
+     fxyz = torch.stack([fx, fy, fz], dim=-3)
+
+     # Convert from Lab to XYZ
+     power = torch.pow(fxyz, 3.0)
+     scale = (fxyz - 4.0 / 29.0) / 7.787
+     xyz = torch.where(fxyz > 0.2068966, power, scale)
+
+     # For D65 white point
+     xyz_ref_white = torch.tensor(
+         [0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[..., :, None, None]
+     xyz_im = xyz * xyz_ref_white
+
+     rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im)
+
+     # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107
+     # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im)
+
+     # Convert from RGB Linear to sRGB
+     rgb_im = linear_rgb_to_rgb(rgbs_im)
+
+     # Clip to 0,1 https://www.w3.org/Graphics/Color/srgb
+     if clip:
+         rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)
+
+     return rgb_im
+
+
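rgb_to_lab and lab_to_rgb are intended to be mutual inverses for in-gamut sRGB values in [0, 1]. A minimal round-trip check, not part of the commit, assuming tools.util is imported as a whole (lab_to_rgb relies on xyz_to_rgb and linear_rgb_to_rgb defined below):

import torch
from tools.util import rgb_to_lab, lab_to_rgb

rgb = torch.rand(2, 3, 16, 16)
back = lab_to_rgb(rgb_to_lab(rgb))
print((back - rgb).abs().max())       # expected: small, roughly 1e-4 or below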
+ def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
+     r"""Convert an RGB image to XYZ.
+
+     .. image:: _static/img/rgb_to_xyz.png
+
+     Args:
+         image: RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`.
+
+     Returns:
+         XYZ version of the image with shape :math:`(*, 3, H, W)`.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = rgb_to_xyz(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     r: torch.Tensor = image[..., 0, :, :]
+     g: torch.Tensor = image[..., 1, :, :]
+     b: torch.Tensor = image[..., 2, :, :]
+
+     x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b
+     y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b
+     z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b
+
+     out: torch.Tensor = torch.stack([x, y, z], -3)
+
+     return out
+
+
+ def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
+     r"""Convert an XYZ image to RGB.
+
+     Args:
+         image: XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`.
+
+     Returns:
+         RGB version of the image with shape :math:`(*, 3, H, W)`.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = xyz_to_rgb(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     x: torch.Tensor = image[..., 0, :, :]
+     y: torch.Tensor = image[..., 1, :, :]
+     z: torch.Tensor = image[..., 2, :, :]
+
+     r: torch.Tensor = 3.2404813432005266 * x - \
+         1.5371515162713185 * y - 0.4985363261688878 * z
+     g: torch.Tensor = -0.9692549499965682 * x + \
+         1.8759900014898907 * y + 0.0415559265582928 * z
+     b: torch.Tensor = 0.0556466391351772 * x - \
+         0.2040413383665112 * y + 1.0573110696453443 * z
+
+     out: torch.Tensor = torch.stack([r, g, b], dim=-3)
+
+     return out
+
+
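The two 3x3 coefficient matrices in rgb_to_xyz and xyz_to_rgb are meant to be inverses of each other (linear RGB <-> XYZ under D65). A quick numeric check of that assumption, illustrative only:

import numpy as np

M = np.array([[0.412453, 0.357580, 0.180423],
              [0.212671, 0.715160, 0.072169],
              [0.019334, 0.119193, 0.950227]])
M_inv = np.array([[3.2404813432005266, -1.5371515162713185, -0.4985363261688878],
                  [-0.9692549499965682, 1.8759900014898907, 0.0415559265582928],
                  [0.0556466391351772, -0.2040413383665112, 1.0573110696453443]])
print(np.abs(M @ M_inv - np.eye(3)).max())  # expected: ~1e-5 or smaller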
+ def rgb_to_linear_rgb(image: torch.Tensor) -> torch.Tensor:
+     r"""Convert an sRGB image to linear RGB. Used in colorspace conversions.
+
+     .. image:: _static/img/rgb_to_linear_rgb.png
+
+     Args:
+         image: sRGB Image to be converted to linear RGB of shape :math:`(*,3,H,W)`.
+
+     Returns:
+         linear RGB version of the image with shape of :math:`(*,3,H,W)`.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = rgb_to_linear_rgb(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     lin_rgb: torch.Tensor = torch.where(image > 0.04045, torch.pow(
+         ((image + 0.055) / 1.055), 2.4), image / 12.92)
+
+     return lin_rgb
+
+
+ def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor:
+     r"""Convert a linear RGB image to sRGB. Used in colorspace conversions.
+
+     Args:
+         image: linear RGB Image to be converted to sRGB of shape :math:`(*,3,H,W)`.
+
+     Returns:
+         sRGB version of the image with shape of :math:`(*,3,H,W)`.
+
+     Example:
+         >>> input = torch.rand(2, 3, 4, 5)
+         >>> output = linear_rgb_to_rgb(input)  # 2x3x4x5
+     """
+     if not isinstance(image, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
+
+     if len(image.shape) < 3 or image.shape[-3] != 3:
+         raise ValueError(
+             f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+     threshold = 0.0031308
+     rgb: torch.Tensor = torch.where(
+         image > threshold, 1.055 *
+         torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image
+     )
+
+     return rgb
+
+
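rgb_to_linear_rgb and linear_rgb_to_rgb implement the two directions of the standard sRGB transfer curve (linear segment below the 0.04045 / 0.0031308 thresholds, a 2.4 exponent above), so a round trip should be lossless up to float precision. Illustrative check, not part of the commit:

import torch
from tools.util import rgb_to_linear_rgb, linear_rgb_to_rgb

srgb = torch.linspace(0.0, 1.0, 101).reshape(1, 1, 101).repeat(3, 1, 1)  # shape (3, 1, 101)
back = linear_rgb_to_rgb(rgb_to_linear_rgb(srgb))
print((back - srgb).abs().max())      # expected: ~1e-6 or smaller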
+ def inference_img(model, img, device='cpu'):
+     h, w, _ = img.shape
+     # print(img.shape)
+     # pad so that height and width are multiples of 8; the padding is cropped
+     # away again by the [-h:, -w:] slice after inference
+     if h % 8 != 0 or w % 8 != 0:
+         img = cv2.copyMakeBorder(img, (8 - h % 8) % 8, 0, (8 - w % 8) % 8, 0,
+                                  cv2.BORDER_REFLECT)
+     # print(img.shape)
+
+     tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
+     input_t = tensor_img
+     input_t = input_t / 255.0
+     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+     input_t = normalize(input_t)
+     input_t = input_t.unsqueeze(0).float()
+     with torch.no_grad():
+         out = model(input_t)
+     # print("out", out.shape)
+     result = out[0][:, -h:, -w:].cpu().numpy()
+     # print(result.shape)
+
+     return result[0]
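A hypothetical usage sketch of inference_img (the dummy model below is illustration only; the real model wiring lives in app.py and tools/inference.py, and the matting network itself in tools/stylematte.py). It shows the expected interface: an HWC uint8 RGB array in, a single-channel map cropped back to the original height and width out:

import numpy as np
import torch.nn as nn
from tools.util import inference_img

# stand-in for the real network: any module mapping (1, 3, H, W) -> (1, 1, H, W)
dummy_model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid()).eval()

rgb = (np.random.rand(250, 333, 3) * 255).astype(np.uint8)  # deliberately not a multiple of 8
alpha = inference_img(dummy_model, rgb)
print(alpha.shape)                    # expected: (250, 333)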