Spaces: befozg — Running

befozg committed
Commit f0de4e8 · 1 Parent(s): 9967c2f

added initial portrait transfer app
Browse files
- .gitignore +174 -0
- app.py +107 -0
- requirements.txt +38 -0
- slider.html +137 -0
- tools/__init__.py +3 -0
- tools/inference.py +56 -0
- tools/model.py +296 -0
- tools/normalizer.py +261 -0
- tools/stylematte.py +506 -0
- tools/util.py +345 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec


# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

config/*
trainer/__pycache__/
trainer/__pycache__/*
__pycache__/*
checkpoints/*.pth
*/*.pth
*/checkpoints/best_pure.pth
checkpoints/best_pure.pth
*.ipynb
.ipynb_checkpoints/*
flagged/
assets/
app.py
ADDED
@@ -0,0 +1,107 @@
import gradio as gr
from tools import Inference, Matting, log
from omegaconf import OmegaConf
import os
import sys
import numpy as np
import torchvision.transforms.functional as tf
from PIL import Image

args = OmegaConf.load(os.path.join(f"./config/test.yaml"))

global_comp = None
global_mask = None

log("Model loading")
phnet = Inference(**args)
stylematte = Matting(**args)
log("Model loaded")


def harmonize(comp, mask):
    log("Inference started")
    if comp is None or mask is None:
        log("Empty source")
        return np.zeros((16, 16, 3))

    comp = comp.convert('RGB')
    mask = mask.convert('1')
    in_shape = comp.size[::-1]

    comp = tf.resize(comp, [args.image_size, args.image_size])
    mask = tf.resize(mask, [args.image_size, args.image_size])

    compt = tf.to_tensor(comp)
    maskt = tf.to_tensor(mask)
    res = phnet.harmonize(compt, maskt)
    res = tf.resize(res, in_shape)

    log("Inference finished")

    return np.uint8((res*255)[0].permute(1, 2, 0).numpy())


def extract_matte(img, back):
    mask, fg = stylematte.extract(img)
    fg_pil = Image.fromarray(np.uint8(fg))

    composite = fg + (1 - mask[:, :, None]) * \
        np.array(back.resize(mask.shape[::-1]))
    composite_pil = Image.fromarray(np.uint8(composite))

    global_comp = composite_pil
    global_mask = mask

    return [composite_pil, mask, fg_pil]


def css(height=3, scale=2):
    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"


with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Welcome to portrait transfer demo app!
    Select source portrait image and new background.
    """)
    btn_compose = gr.Button(value="Compose")

    with gr.Row():
        input_ui = gr.Image(
            type="numpy", label='Source image to extract foreground')
        back_ui = gr.Image(type="pil", label='The new background')

    gr.Examples(
        examples=[["./assets/comp.jpg", "./assets/back.jpg"]],
        inputs=[input_ui, back_ui],
    )

    gr.Markdown(
        """
    ## Resulting alpha matte and extracted foreground.
    """)
    with gr.Row():
        matte_ui = gr.Image(type="pil", label='Alpha matte')
        fg_ui = gr.Image(type="pil", image_mode='RGBA',
                         label='Extracted foreground')

    gr.Markdown(
        """
    ## Click the button and compare the composite with the harmonized version.
    """)
    btn_harmonize = gr.Button(value="Harmonize composite")

    with gr.Row():
        composite_ui = gr.Image(type="pil", label='Composite')
        harmonized_ui = gr.Image(
            type="pil", label='Harmonized composite', css=css(3, 3))

    btn_compose.click(extract_matte, inputs=[input_ui, back_ui], outputs=[
                      composite_ui, matte_ui, fg_ui])
    btn_harmonize.click(harmonize, inputs=[
                        composite_ui, matte_ui], outputs=[harmonized_ui])


log("Interface created")
demo.launch(share=True)
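The app wires two stages together: StyleMatte extracts an alpha matte and foreground, the pair is pasted onto the new background, and PHNet harmonizes the result. A minimal sketch of running the same pipeline outside the Gradio UI is shown below; it assumes the repo's config/test.yaml and the checkpoints it references are present, and the image paths are placeholders.

# Sketch only (not part of this commit); paths and sizes are placeholder assumptions.
import numpy as np
from PIL import Image
import torchvision.transforms.functional as tf
from omegaconf import OmegaConf
from tools import Inference, Matting

args = OmegaConf.load("./config/test.yaml")
phnet, stylematte = Inference(**args), Matting(**args)

source = np.array(Image.open("portrait.jpg").convert("RGB"))   # placeholder path
background = Image.open("beach.jpg").convert("RGB")            # placeholder path

# Stage 1: matting — alpha mask and extracted foreground, then a naive paste.
mask, fg = stylematte.extract(source)
composite = fg + (1 - mask[:, :, None]) * np.array(background.resize(mask.shape[::-1]))

# Stage 2: harmonization — PHNet blends the pasted foreground into the background,
# mirroring the preprocessing done in harmonize() above.
comp_pil = Image.fromarray(np.uint8(composite))
mask_pil = Image.fromarray(np.uint8(mask * 255)).convert("1")
comp_t = tf.to_tensor(tf.resize(comp_pil, [args.image_size, args.image_size]))
mask_t = tf.to_tensor(tf.resize(mask_pil, [args.image_size, args.image_size]))
result = phnet.harmonize(comp_t, mask_t)                        # (1, 3, H, W) in [0, 1]
Image.fromarray(np.uint8((result * 255)[0].permute(1, 2, 0).numpy())).save("harmonized.jpg")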
requirements.txt
ADDED
@@ -0,0 +1,38 @@
gradio==3.30.0
gradio_client==0.2.4
huggingface-hub==0.14.1
imageio==2.25.1
imgcat==0.5.0
ipykernel==6.16.0
ipython==8.5.0
ipywidgets==8.0.2
kiwisolver==1.4.2
kornia==0.6.9
legacy==0.1.6
numpy==1.21.6
omegaconf==2.2.3
opencv-python==4.5.5.62
opencv-python-headless==4.7.0.68
packaging==21.3
pandas==1.4.2
parso==0.8.3
Pillow==9.4.0
protobuf==3.20.1
Pygments==2.13.0
PyMatting==1.1.8
pyparsing==3.0.9
pyrsistent==0.19.3
scikit-image==0.19.3
scikit-learn==1.1.1
scipy==1.10.0
seaborn==0.12.2
sklearn==0.0
sniffio==1.3.0
soupsieve==2.4
timm==0.6.12
torch==1.11.0
torchaudio==0.11.0
torchvision==0.12.0
tornado==6.2
tqdm==4.64.1
transformers==4.28.1
slider.html
ADDED
@@ -0,0 +1,137 @@
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
* {box-sizing: border-box;}

.img-comp-container {
  position: relative;
  height: 200px; /*should be the same height as the images*/
}

.img-comp-img {
  position: absolute;
  width: auto;
  height: auto;
  overflow:hidden;
}

.img-comp-img img {
  display:block;
  vertical-align:middle;
}

.img-comp-slider {
  position: absolute;
  z-index:9;
  cursor: ew-resize;
  /*set the appearance of the slider:*/
  width: 40px;
  height: 40px;
  background-color: #2196F3;
  opacity: 0.7;
  border-radius: 50%;
}
</style>
<script>
function initComparisons() {
  var x, i;
  /*find all elements with an "overlay" class:*/
  x = document.getElementsByClassName("img-comp-overlay");
  for (i = 0; i < x.length; i++) {
    /*once for each "overlay" element:
    pass the "overlay" element as a parameter when executing the compareImages function:*/
    compareImages(x[i]);
  }
  function compareImages(img) {
    var slider, img, clicked = 0, w, h;
    /*get the width and height of the img element*/
    w = img.offsetWidth;
    h = img.offsetHeight;
    /*set the width of the img element to 50%:*/
    img.style.width = (w / 2) + "px";
    /*create slider:*/
    slider = document.createElement("DIV");
    slider.setAttribute("class", "img-comp-slider");
    /*insert slider*/
    img.parentElement.insertBefore(slider, img);
    /*position the slider in the middle:*/
    slider.style.top = (h / 2) - (slider.offsetHeight / 2) + "px";
    slider.style.left = (w / 2) - (slider.offsetWidth / 2) + "px";
    /*execute a function when the mouse button is pressed:*/
    slider.addEventListener("mousedown", slideReady);
    /*and another function when the mouse button is released:*/
    window.addEventListener("mouseup", slideFinish);
    /*or touched (for touch screens:*/
    slider.addEventListener("touchstart", slideReady);
    /*and released (for touch screens:*/
    window.addEventListener("touchend", slideFinish);
    function slideReady(e) {
      /*prevent any other actions that may occur when moving over the image:*/
      e.preventDefault();
      /*the slider is now clicked and ready to move:*/
      clicked = 1;
      /*execute a function when the slider is moved:*/
      window.addEventListener("mousemove", slideMove);
      window.addEventListener("touchmove", slideMove);
    }
    function slideFinish() {
      /*the slider is no longer clicked:*/
      clicked = 0;
    }
    function slideMove(e) {
      var pos;
      /*if the slider is no longer clicked, exit this function:*/
      if (clicked == 0) return false;
      /*get the cursor's x position:*/
      pos = getCursorPos(e)
      /*prevent the slider from being positioned outside the image:*/
      if (pos < 0) pos = 0;
      if (pos > w) pos = w;
      /*execute a function that will resize the overlay image according to the cursor:*/
      slide(pos);
    }
    function getCursorPos(e) {
      var a, x = 0;
      e = (e.changedTouches) ? e.changedTouches[0] : e;
      /*get the x positions of the image:*/
      a = img.getBoundingClientRect();
      /*calculate the cursor's x coordinate, relative to the image:*/
      x = e.pageX - a.left;
      /*consider any page scrolling:*/
      x = x - window.pageXOffset;
      return x;
    }
    function slide(x) {
      /*resize the image:*/
      img.style.width = x + "px";
      /*position the slider:*/
      slider.style.left = img.offsetWidth - (slider.offsetWidth / 2) + "px";
    }
  }
}
</script>
</head>
<body>

<h1>Compare Two Images</h1>

<p>Click and slide the blue slider to compare two images:</p>

<div class="img-comp-container">
  <div class="img-comp-img">
    <img src="img_snow.jpg" width="300" height="200">
  </div>
  <div class="img-comp-img img-comp-overlay">
    <img src="img_forest.jpg" width="300" height="200">
  </div>
</div>

<script>
/*Execute a function that will execute an image compare function for each element with the img-comp-overlay class:*/
initComparisons();
</script>

</body>
</html>
tools/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .inference import Inference
from .inference import Matting
from .util import log
tools/inference.py
ADDED
@@ -0,0 +1,56 @@
import torch
from .model import PHNet
import torchvision.transforms.functional as tf
from .util import inference_img, log
from .stylematte import StyleMatte
import numpy as np


class Inference:
    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)

        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        if len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        composite = tf.resize(composite, [self.image_size, self.image_size])
        mask = tf.resize(mask, [self.image_size, self.image_size])
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']

        result = harmonized * mask + composite * (1-mask)
        print(result.shape)
        return result


class Matting:
    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        self.model = StyleMatte().to(self.device)
        log(f"checkpoint: {self.checkpoint.matting}")
        state = torch.load(self.checkpoint.matting, map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def extract(self, inp):
        mask = inference_img(self.model, inp, self.device)
        inp_np = np.array(inp)
        fg = mask[:, :, None]*inp_np

        return [mask, fg]
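Both wrappers simply splat a config mapping into their instance dict, so the fields they expect are only visible from the attribute accesses above. The sketch below reconstructs a plausible config from those accesses; the real config/test.yaml is not part of this diff, so every value and path here is a placeholder assumption.

# Sketch only: config keys inferred from inference.py; values/paths are placeholders.
from omegaconf import OmegaConf
from tools import Inference, Matting

cfg = OmegaConf.create({
    "device": "cpu",
    "image_size": 512,
    "enc_sizes": [3, 16, 32, 64, 128, 256, 512],
    "skips": True,
    "grid_counts": [10, 5, 1],
    "init_weights": [0.5, 0.5],
    "init_value": 0.8,
    "checkpoint": {
        "harmonizer": "checkpoints/harmonizer.pth",   # placeholder path
        "matting": "checkpoints/stylematte.pth",      # placeholder path
    },
})

harmonizer = Inference(**cfg)   # builds PHNet, loads weights, switches to eval mode
matting = Matting(**cfg)        # builds StyleMatte, loads weights, switches to eval mode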
tools/model.py
ADDED
@@ -0,0 +1,296 @@
from matplotlib import pyplot as plt
# from shtools import shReconstructSignal
from torchvision import transforms, utils
# from torchvision.ops import SqueezeExcitation
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch
import math
import cv2
import numpy as np
from .normalizer import PatchNormalizer, PatchedHarmonizer
from .util import rgb_to_lab, lab_to_rgb, lab_shift

# from shtools import *
# from color_converters import luv_to_rgb, rgb_to_luv
# from skimage import io, transform
'''
Input (256,512,3)
'''


def inpaint_bg(comp, mask, dim=[2, 3]):
    """
    inpaint bg for ihd
    Args:
        comp (torch.float): [0:1]
        mask (torch.float): [0:1]
    """
    back = comp * (1-mask)  # *255
    sum = torch.sum(back, dim=dim)  # (B, C)
    num = torch.sum((1-mask), dim=dim)  # (B, C)
    mu = sum / (num)
    mean = mu[:, :, None, None]
    back = back + mask * mean

    return back


class ConvTransposeUp(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=4, padding=1, stride=2, activation=None):
        super().__init__(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=kernel_size, padding=padding, stride=stride),
            activation() if activation is not None else nn.Identity(),
        )


class UpsampleShuffle(nn.Sequential):
    def __init__(self, in_channels, out_channels, activation=True):
        super().__init__(
            nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
            nn.GELU() if activation else nn.Identity(),
            nn.PixelShuffle(2)
        )

    def reset_parameters(self):
        init_subpixel(self[0].weight)
        nn.init.zeros_(self[0].bias)


class UpsampleResize(nn.Sequential):
    def __init__(self, in_channels, out_channels, out_size=None, activation=None, scale_factor=2., mode='bilinear'):
        super().__init__(
            nn.Upsample(scale_factor=scale_factor, mode=mode) if out_size is None else nn.Upsample(
                out_size, mode=mode),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=0),
            activation() if activation is not None else nn.Identity(),

        )


def conv_bn(in_, out_, kernel_size=3, stride=1, padding=1, activation=nn.ReLU, normalization=nn.InstanceNorm2d):

    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=padding),
        normalization(out_) if normalization is not None else nn.Identity(),
        activation(),
    )


def init_subpixel(weight):
    co, ci, h, w = weight.shape
    co2 = co // 4
    # initialize sub kernel
    k = torch.empty([co2, ci, h, w])
    nn.init.kaiming_uniform_(k)
    # repeat 4 times
    k = k.repeat_interleave(4, dim=0)
    weight.data.copy_(k)


class DownsampleShuffle(nn.Sequential):
    def __init__(self, in_channels):
        assert in_channels % 4 == 0
        super().__init__(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(),
            nn.PixelUnshuffle(2)
        )

    def reset_parameters(self):
        init_subpixel(self[0].weight)
        nn.init.zeros_(self[0].bias)


def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
    # conv layer with ELU activation function
    pad = int(kernel_size/2)
    if padding is False:
        pad = 0
    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
        nn.ELU(),
    )


class Inference_Data(Dataset):
    def __init__(self, img_path):
        self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        self.input_img = cv2.resize(
            self.input_img, (512, 256), interpolation=cv2.INTER_CUBIC)
        self.to_tensor = transforms.ToTensor()
        self.data_len = 1

    def __getitem__(self, index):
        self.tensor_img = self.to_tensor(self.input_img)
        return self.tensor_img

    def __len__(self):
        return self.data_len


class SEBlock(nn.Module):
    def __init__(self, channel, reducation=8):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel//reducation),
            nn.ReLU(inplace=True),
            nn.Linear(channel//reducation, channel),
            nn.Sigmoid())

    def forward(self, x, aux_inp=None):
        b, c, w, h = x.size()

        def scale(x):
            return (x - x.min()) / (x.max() - x.min() + 1e-8)
        y1 = self.avg_pool(x).view(b, c)
        y = self.fc(y1).view(b, c, 1, 1)
        r = x*y
        if aux_inp is not None:
            aux_weitghts = nn.AdaptiveAvgPool2d(aux_inp.shape[-1]//8)(aux_inp)
            aux_weitghts = nn.Sigmoid()(aux_weitghts.mean(1, keepdim=True))
            tmp = x*aux_weitghts
            tmp_img = (tmp - tmp.min()) / (tmp.max() - tmp.min())
            r += tmp

        return r


class ConvTransposeUp(nn.Sequential):
    def __init__(self, in_channels, out_channels, norm, kernel_size=3, stride=2, padding=1, activation=None):
        super().__init__(
            nn.ConvTranspose2d(in_channels, out_channels,
                               # output_padding=output_padding, dilation=dilation
                               kernel_size=kernel_size, padding=padding, stride=stride,
                               ),
            norm(out_channels) if norm is not None else nn.Identity(),
            activation() if activation is not None else nn.Identity(),
        )


class SkipConnect(nn.Module):
    """docstring for RegionalSkipConnect"""

    def __init__(self, channel):
        super(SkipConnect, self).__init__()
        self.rconv = nn.Conv2d(channel*2, channel, 3, padding=1, bias=False)

    def forward(self, feature):
        return F.relu(self.rconv(feature))


class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        self.attn = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.attn(x)


class PatchHarmonizerBlock(nn.Module):
    def __init__(self, in_channels=3, grid_count=5):
        super(PatchHarmonizerBlock, self).__init__()
        self.patch_harmonizer = PatchedHarmonizer(grid_count=grid_count)
        self.head = conv_bn(in_channels*2, in_channels,
                            kernel_size=3, padding=1, normalization=None)

    def forward(self, fg, bg, mask):
        fg_harm, _ = self.patch_harmonizer(fg, bg, mask)

        return self.head(torch.cat([fg, fg_harm], 1))


class PHNet(nn.Module):
    def __init__(self, enc_sizes=[3, 16, 32, 64, 128, 256, 512], skips=True, grid_count=[10, 5, 1], init_weights=[0.5, 0.5], init_value=0.8):
        super(PHNet, self).__init__()
        self.skips = skips
        self.feature_extractor = PatchHarmonizerBlock(
            in_channels=enc_sizes[0], grid_count=grid_count[1])
        self.encoder = nn.ModuleList([
            conv_bn(enc_sizes[0], enc_sizes[1],
                    kernel_size=4, stride=2),
            conv_bn(enc_sizes[1], enc_sizes[2],
                    kernel_size=3, stride=1),
            conv_bn(enc_sizes[2], enc_sizes[3],
                    kernel_size=4, stride=2),
            conv_bn(enc_sizes[3], enc_sizes[4],
                    kernel_size=3, stride=1),
            conv_bn(enc_sizes[4], enc_sizes[5],
                    kernel_size=4, stride=2),
            conv_bn(enc_sizes[5], enc_sizes[6],
                    kernel_size=3, stride=1),
        ])

        dec_ins = enc_sizes[::-1]
        dec_sizes = enc_sizes[::-1]
        self.start_level = len(dec_sizes) - len(grid_count)
        self.normalizers = nn.ModuleList([
            PatchNormalizer(in_channels=dec_sizes[self.start_level+i], grid_count=count, weights=init_weights, eps=1e-7, init_value=init_value) for i, count in enumerate(grid_count)
        ])

        self.decoder = nn.ModuleList([
            ConvTransposeUp(
                dec_ins[0], dec_sizes[1], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[1], dec_sizes[2], norm=nn.BatchNorm2d, kernel_size=4, stride=2, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[2], dec_sizes[3], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[3], dec_sizes[4], norm=None, kernel_size=4, stride=2, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[4], dec_sizes[5], norm=None, kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[5], 3, norm=None, kernel_size=4, stride=2, activation=None),
        ])

        self.skip = nn.ModuleList([
            SkipConnect(x) for x in dec_ins
        ])

        self.SE_block = SEBlock(enc_sizes[6])

    def forward(self, img, mask):
        x = img

        enc_outs = [x]
        x_harm = self.feature_extractor(x*mask, x*(1-mask), mask)

        # x = x_harm
        masks = [mask]
        for i, down_layer in enumerate(self.encoder):
            x = down_layer(x)
            scale_factor = 1. / (pow(2, 1 - i % 2))
            masks.append(F.interpolate(masks[-1], scale_factor=scale_factor))
            enc_outs.append(x)

        x = self.SE_block(x, aux_inp=x_harm)

        masks = masks[::-1]
        for i, (up_layer, enc_out) in enumerate(zip(self.decoder, enc_outs[::-1])):
            if i >= self.start_level:
                enc_out = self.normalizers[i -
                                           self.start_level](enc_out, enc_out, masks[i])
            x = torch.cat([x, enc_out], 1)
            x = self.skip[i](x)
            x = up_layer(x)

        relighted = F.sigmoid(x)

        return {
            "harmonized": relighted,  # target prediction
        }

    def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"], value=False):
        for module in modules:
            attr = getattr(self, module, None)
            if attr is not None:
                attr.requires_grad_(value)
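PHNet takes a composite and its foreground mask and returns a dict with the harmonized prediction. A shape smoke test is sketched below under the assumption of a square input whose side is divisible by 8 (e.g. 512), which keeps the three stride-2 encoder stages and the SE-block broadcasting consistent; it is an illustration, not part of the commit.

# Sketch only: random tensors, default PHNet hyperparameters.
import torch
from tools.model import PHNet

net = PHNet()                        # enc_sizes=[3, 16, 32, 64, 128, 256, 512], grid_count=[10, 5, 1]
comp = torch.rand(1, 3, 512, 512)    # composite image in [0, 1]
mask = torch.rand(1, 1, 512, 512)    # soft foreground mask in [0, 1]
with torch.no_grad():
    out = net(comp, mask)["harmonized"]
print(out.shape)                     # expected: torch.Size([1, 3, 512, 512])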
tools/normalizer.py
ADDED
@@ -0,0 +1,261 @@
import numpy as np
import cv2
import os
import tqdm
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util import rgb_to_lab, lab_to_rgb


def blend(f, b, a):
    return f*a + b*(1 - a)


class PatchedHarmonizer(nn.Module):
    def __init__(self, grid_count=1, init_weights=[0.9, 0.1]):
        super(PatchedHarmonizer, self).__init__()
        self.eps = 1e-8
        # self.weights = torch.nn.Parameter(torch.ones((grid_count, grid_count)), requires_grad=True)
        # self.grid_weights_ = torch.nn.Parameter(torch.FloatTensor(init_weights), requires_grad=True)
        self.grid_weights = torch.nn.Parameter(
            torch.FloatTensor(init_weights), requires_grad=True)
        # self.weights.retain_graph = True
        self.grid_count = grid_count

    def lab_shift(self, x, invert=False):
        x = x.float()
        if invert:
            x[:, 0, :, :] /= 2.55
            x[:, 1, :, :] -= 128
            x[:, 2, :, :] -= 128
        else:
            x[:, 0, :, :] *= 2.55
            x[:, 1, :, :] += 128
            x[:, 2, :, :] += 128

        return x

    def get_mean_std(self, img, mask, dim=[2, 3]):
        sum = torch.sum(img*mask, dim=dim)  # (B, C)
        num = torch.sum(mask, dim=dim)  # (B, C)
        mu = sum / (num + self.eps)
        mean = mu[:, :, None, None]
        var = torch.sum(((img - mean)*mask) ** 2, dim=dim) / (num + self.eps)
        var = var[:, :, None, None]

        return mean, torch.sqrt(var+self.eps)

    def compute_patch_statistics(self, lab):
        means, stds = [], []
        bs, dx, dy = lab.shape[0], lab.shape[2] // self.grid_count, lab.shape[3] // self.grid_count
        for h in range(self.grid_count):
            cmeans, cstds = [], []
            for w in range(self.grid_count):
                ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
                if h == self.grid_count - 1:
                    ind[1] = None
                if w == self.grid_count - 1:
                    ind[-1] = None
                m, v = self.compute_mean_var(
                    lab[:, :, ind[0]:ind[1], ind[2]:ind[3]], dim=[2, 3])
                cmeans.append(m)
                cstds.append(v)
            means.append(cmeans)
            stds.append(cstds)

        return means, stds

    def compute_mean_var(self, x, dim=[1, 2]):
        mean = x.float().mean(dim=dim)[:, :, None, None]
        var = torch.sqrt(x.float().var(dim=dim))[:, :, None, None]

        return mean, var

    def forward(self, fg_rgb, bg_rgb, alpha, masked_stats=False):

        bg_rgb = F.interpolate(bg_rgb, size=(
            fg_rgb.shape[2:]))  # b x C x H x W

        bg_lab = bg_rgb  # self.lab_shift(rgb_to_lab(bg_rgb/255.))
        fg_lab = fg_rgb  # self.lab_shift(rgb_to_lab(fg_rgb/255.))

        if masked_stats:
            self.bg_global_mean, self.bg_global_var = self.get_mean_std(
                img=bg_lab, mask=(1-alpha))
            self.fg_global_mean, self.fg_global_var = self.get_mean_std(
                img=fg_lab, mask=torch.ones_like(alpha))
        else:
            self.bg_global_mean, self.bg_global_var = self.compute_mean_var(bg_lab, dim=[
                2, 3])
            self.fg_global_mean, self.fg_global_var = self.compute_mean_var(fg_lab, dim=[
                2, 3])

        self.bg_means, self.bg_vars = self.compute_patch_statistics(
            bg_lab)
        self.fg_means, self.fg_vars = self.compute_patch_statistics(
            fg_lab)

        fg_harm = self.harmonize(fg_lab)
        # fg_harm = lab_to_rgb(fg_harm)
        bg = F.interpolate(bg_rgb, size=(fg_rgb.shape[2:]))/255.

        composite = blend(fg_harm, bg, alpha)

        return composite, fg_harm

    def harmonize(self, fg):
        harmonized = torch.zeros_like(fg)
        dx = fg.shape[2] // self.grid_count
        dy = fg.shape[3] // self.grid_count
        for h in range(self.grid_count):
            for w in range(self.grid_count):
                ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
                if h == self.grid_count - 1:
                    ind[1] = None
                if w == self.grid_count - 1:
                    ind[-1] = None
                harmonized[:, :, ind[0]:ind[1], ind[2]:ind[3]] = self.normalize_channel(
                    fg[:, :, ind[0]:ind[1], ind[2]:ind[3]], h, w)

        # harmonized = self.lab_shift(harmonized, invert=True)

        return harmonized

    def normalize_channel(self, value, h, w):

        fg_local_mean, fg_local_var = self.fg_means[h][w], self.fg_vars[h][w]
        bg_local_mean, bg_local_var = self.bg_means[h][w], self.bg_vars[h][w]
        fg_global_mean, fg_global_var = self.fg_global_mean, self.fg_global_var
        bg_global_mean, bg_global_var = self.bg_global_mean, self.bg_global_var

        # global2global normalization
        zeroed_mean = value - fg_global_mean
        # (fg_v * div_global_v + (1-fg_v) * div_v)
        scaled_var = zeroed_mean * (bg_global_var/(fg_global_var + self.eps))
        normalized_global = scaled_var + bg_global_mean

        # local2local normalization
        zeroed_mean = value - fg_local_mean
        # (fg_v * div_global_v + (1-fg_v) * div_v)
        scaled_var = zeroed_mean * (bg_local_var/(fg_local_var + self.eps))
        normalized_local = scaled_var + bg_local_mean

        return self.grid_weights[0]*normalized_local + self.grid_weights[1]*normalized_global

    def normalize_fg(self, value):
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        # (fg_v * div_global_v + (1-fg_v) * div_v)
        scaled_var = zeroed_mean * \
            (self.bg_global_var/(self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()

        return normalized_lg


class PatchNormalizer(nn.Module):
    def __init__(self, in_channels=3, eps=1e-7, grid_count=1, weights=[0.5, 0.5], init_value=1e-2):
        super(PatchNormalizer, self).__init__()
        self.grid_count = grid_count
        self.eps = eps

        self.weights = nn.Parameter(
            torch.FloatTensor(weights), requires_grad=True)
        self.fg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.fg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        self.patched_fg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.patched_fg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        self.bg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.bg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        self.grid_weights = torch.nn.Parameter(torch.ones((in_channels, grid_count, grid_count))[
            None, :, :, :] / (grid_count*grid_count*in_channels), requires_grad=True)

    def local_normalization(self, value):
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        # (fg_v * div_global_v + (1-fg_v) * div_v)
        scaled_var = zeroed_mean * \
            (self.bg_global_var/(self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()

        return normalized_lg

    def get_mean_std(self, img, mask, dim=[2, 3]):
        sum = torch.sum(img*mask, dim=dim)  # (B, C)
        num = torch.sum(mask, dim=dim)  # (B, C)
        mu = sum / (num + self.eps)
        mean = mu[:, :, None, None]
        var = torch.sum(((img - mean)*mask) ** 2, dim=dim) / (num + self.eps)
        var = var[:, :, None, None]

        return mean, torch.sqrt(var+self.eps)

    def compute_patch_statistics(self, img, mask):
        means, stds = [], []
        bs, dx, dy = img.shape[0], img.shape[2] // self.grid_count, img.shape[3] // self.grid_count
        for h in range(self.grid_count):
            cmeans, cstds = [], []
            for w in range(self.grid_count):
                ind = [h*dx, (h+1)*dx, w*dy, (w+1)*dy]
                if h == self.grid_count - 1:
                    ind[1] = None
                if w == self.grid_count - 1:
                    ind[-1] = None
                m, v = self.get_mean_std(
                    img[:, :, ind[0]:ind[1], ind[2]:ind[3]], mask[:, :, ind[0]:ind[1], ind[2]:ind[3]], dim=[2, 3])
                cmeans.append(m.reshape(m.shape[:2]))
                cstds.append(v.reshape(v.shape[:2]))
            means.append(torch.stack(cmeans))
            stds.append(torch.stack(cstds))

        return torch.stack(means), torch.stack(stds)

    def compute_mean_var(self, x, dim=[2, 3]):
        mean = x.float().mean(dim=dim)
        var = torch.sqrt(x.float().var(dim=dim))

        return mean, var

    def forward(self, fg, bg, mask):

        self.local_means, self.local_vars = self.compute_patch_statistics(
            bg, (1-mask))

        bg_mean, bg_var = self.get_mean_std(bg, 1 - mask)
        zeroed_mean = (bg - bg_mean)
        unscaled = zeroed_mean / bg_var
        bg_normalized = unscaled * self.bg_var + self.bg_bias

        fg_mean, fg_var = self.get_mean_std(fg, mask)
        zeroed_mean = fg - fg_mean
        unscaled = zeroed_mean / fg_var

        mean_patched_back = (self.local_means.permute(
            2, 3, 0, 1)*self.grid_weights).sum(dim=[2, 3])[:, :, None, None]

        normalized = unscaled * bg_var + bg_mean
        patch_normalized = unscaled * bg_var + mean_patched_back

        fg_normalized = normalized * self.fg_var + self.fg_bias
        fg_patch_normalized = patch_normalized * \
            self.patched_fg_var + self.patched_fg_bias

        fg_result = self.weights[0] * fg_normalized + \
            self.weights[1] * fg_patch_normalized
        composite = blend(fg_result, bg_normalized, mask)

        return composite
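PatchedHarmonizer matches per-patch mean/std of the foreground to the background and mixes the local and global corrections with its learnable grid_weights. A minimal call-signature sketch on random tensors is shown below; it is an illustration under assumed sizes, not part of the commit.

# Sketch only: random tensors, just to show the expected shapes of the call.
import torch
from tools.normalizer import PatchedHarmonizer

ph = PatchedHarmonizer(grid_count=4)
fg = torch.rand(1, 3, 256, 256)       # foreground, values in [0, 1]
bg = torch.rand(1, 3, 256, 256)       # background, values in [0, 1]
alpha = torch.rand(1, 1, 256, 256)    # soft alpha matte
with torch.no_grad():
    composite, fg_harm = ph(fg, bg, alpha)
print(composite.shape, fg_harm.shape)  # both (1, 3, 256, 256)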
tools/stylematte.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from typing import List
|
8 |
+
from itertools import chain
|
9 |
+
|
10 |
+
from transformers import SegformerForSemanticSegmentation, Mask2FormerForUniversalSegmentation
|
11 |
+
device = 'cpu'
|
12 |
+
|
13 |
+
|
14 |
+
class EncoderDecoder(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
encoder,
|
18 |
+
decoder,
|
19 |
+
prefix=nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True),
|
20 |
+
):
|
21 |
+
super().__init__()
|
22 |
+
self.encoder = encoder
|
23 |
+
self.decoder = decoder
|
24 |
+
self.prefix = prefix
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
if self.prefix is not None:
|
28 |
+
x = self.prefix(x)
|
29 |
+
x = self.encoder(x)["hidden_states"] # transformers
|
30 |
+
return self.decoder(x)
|
31 |
+
|
32 |
+
|
33 |
+
def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
|
34 |
+
return nn.Sequential(
|
35 |
+
nn.Conv2d(input_filters, output_filters,
|
36 |
+
kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
|
37 |
+
nn.LeakyReLU(0.2, inplace=True),
|
38 |
+
nn.BatchNorm2d(output_filters)
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def up_and_add(x, y):
|
43 |
+
return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
|
44 |
+
|
45 |
+
|
46 |
+
class FPN_fuse(nn.Module):
|
47 |
+
def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
|
48 |
+
super(FPN_fuse, self).__init__()
|
49 |
+
assert feature_channels[0] == fpn_out
|
50 |
+
self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
|
51 |
+
for ft_size in feature_channels[1:]])
|
52 |
+
self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
|
53 |
+
* (len(feature_channels)-1))
|
54 |
+
self.conv_fusion = nn.Sequential(
|
55 |
+
nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3,
|
56 |
+
padding=1, bias=False),
|
57 |
+
nn.BatchNorm2d(fpn_out),
|
58 |
+
nn.ReLU(inplace=True),
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, features):
|
62 |
+
|
63 |
+
features[:-1] = [conv1x1(feature) for feature,
|
64 |
+
conv1x1 in zip(features[:-1], self.conv1x1)]
|
65 |
+
feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
|
66 |
+
feature = up_and_add(self.smooth_conv[1](feature), features[2])
|
67 |
+
feature = up_and_add(self.smooth_conv[2](feature), features[3])
|
68 |
+
|
69 |
+
H, W = features[-1].size(2), features[-1].size(3)
|
70 |
+
x = [feature, features[-1]]
|
71 |
+
x = [F.interpolate(x_el, size=(H, W), mode='bilinear',
|
72 |
+
align_corners=True) for x_el in x]
|
73 |
+
|
74 |
+
x = self.conv_fusion(torch.cat(x, dim=1))
|
75 |
+
# x = F.interpolate(x, size=(H*4, W*4), mode='bilinear', align_corners=True)
|
76 |
+
return x
|
77 |
+
|
78 |
+
|
79 |
+
class PSPModule(nn.Module):
|
80 |
+
# In the original inmplementation they use precise RoI pooling
|
81 |
+
# Instead of using adaptative average pooling
|
82 |
+
def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
|
83 |
+
super(PSPModule, self).__init__()
|
84 |
+
out_channels = in_channels // len(bin_sizes)
|
85 |
+
self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
|
86 |
+
for b_s in bin_sizes])
|
87 |
+
self.bottleneck = nn.Sequential(
|
88 |
+
nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
|
89 |
+
kernel_size=3, padding=1, bias=False),
|
90 |
+
nn.BatchNorm2d(in_channels),
|
91 |
+
nn.ReLU(inplace=True),
|
92 |
+
nn.Dropout2d(0.1)
|
93 |
+
)
|
94 |
+
|
95 |
+
def _make_stages(self, in_channels, out_channels, bin_sz):
|
96 |
+
prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
|
97 |
+
conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
98 |
+
bn = nn.BatchNorm2d(out_channels)
|
99 |
+
relu = nn.ReLU(inplace=True)
|
100 |
+
return nn.Sequential(prior, conv, bn, relu)
|
101 |
+
|
102 |
+
def forward(self, features):
|
103 |
+
h, w = features.size()[2], features.size()[3]
|
104 |
+
pyramids = [features]
|
105 |
+
pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
|
106 |
+
align_corners=True) for stage in self.stages])
|
107 |
+
output = self.bottleneck(torch.cat(pyramids, dim=1))
|
108 |
+
return output
|
109 |
+
|
110 |
+
|
111 |
+
class UperNet_swin(nn.Module):
|
112 |
+
# Implementing only the object path
|
113 |
+
def __init__(self, backbone, pretrained=True):
|
114 |
+
super(UperNet_swin, self).__init__()
|
115 |
+
|
116 |
+
self.backbone = backbone
|
117 |
+
feature_channels = [192, 384, 768, 768]
|
118 |
+
self.PPN = PSPModule(feature_channels[-1])
|
119 |
+
self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
|
120 |
+
self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
input_size = (x.size()[2], x.size()[3])
|
124 |
+
features = self.backbone(x)["hidden_states"]
|
125 |
+
features[-1] = self.PPN(features[-1])
|
126 |
+
x = self.head(self.FPN(features))
|
127 |
+
|
128 |
+
x = F.interpolate(x, size=input_size, mode='bilinear')
|
129 |
+
return x
|
130 |
+
|
131 |
+
def get_backbone_params(self):
|
132 |
+
return self.backbone.parameters()
|
133 |
+
|
134 |
+
def get_decoder_params(self):
|
135 |
+
return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())
|
136 |
+
|
137 |
+
|
138 |
+
class UnetDecoder(nn.Module):
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
encoder_channels=(3, 192, 384, 768, 768),
|
142 |
+
decoder_channels=(512, 256, 128, 64),
|
143 |
+
n_blocks=4,
|
144 |
+
use_batchnorm=True,
|
145 |
+
attention_type=None,
|
146 |
+
center=False,
|
147 |
+
):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
if n_blocks != len(decoder_channels):
|
151 |
+
raise ValueError(
|
152 |
+
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
|
153 |
+
n_blocks, len(decoder_channels)
|
154 |
+
)
|
155 |
+
)
|
156 |
+
|
157 |
+
# remove first skip with same spatial resolution
|
158 |
+
encoder_channels = encoder_channels[1:]
|
159 |
+
# reverse channels to start from head of encoder
|
160 |
+
encoder_channels = encoder_channels[::-1]
|
161 |
+
|
162 |
+
# computing blocks input and output channels
|
163 |
+
head_channels = encoder_channels[0]
|
164 |
+
in_channels = [head_channels] + list(decoder_channels[:-1])
|
165 |
+
skip_channels = list(encoder_channels[1:]) + [0]
|
166 |
+
|
167 |
+
out_channels = decoder_channels
|
168 |
+
|
169 |
+
if center:
|
170 |
+
self.center = CenterBlock(
|
171 |
+
head_channels, head_channels, use_batchnorm=use_batchnorm)
|
172 |
+
else:
|
173 |
+
self.center = nn.Identity()
|
174 |
+
|
175 |
+
# combine decoder keyword arguments
|
176 |
+
kwargs = dict(use_batchnorm=use_batchnorm,
|
177 |
+
attention_type=attention_type)
|
178 |
+
blocks = [
|
179 |
+
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
|
180 |
+
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
|
181 |
+
]
|
182 |
+
self.blocks = nn.ModuleList(blocks)
|
183 |
+
upscale_factor = 4
|
184 |
+
self.matting_head = nn.Sequential(
|
185 |
+
nn.Conv2d(64, 1, kernel_size=3, padding=1),
|
186 |
+
nn.ReLU(),
|
187 |
+
nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
|
188 |
+
)
|
189 |
+
|
190 |
+
def preprocess_features(self, x):
|
191 |
+
features = []
|
192 |
+
for out_tensor in x:
|
193 |
+
bs, n, f = out_tensor.size()
|
194 |
+
h = int(n**0.5)
|
195 |
+
feature = out_tensor.view(-1, h, h,
|
196 |
+
f).permute(0, 3, 1, 2).contiguous()
|
197 |
+
features.append(feature)
|
198 |
+
return features
|
199 |
+
|
200 |
+
def forward(self, features):
|
201 |
+
# remove first skip with same spatial resolution
|
202 |
+
features = features[1:]
|
203 |
+
# reverse channels to start from head of encoder
|
204 |
+
features = features[::-1]
|
205 |
+
|
206 |
+
        features = self.preprocess_features(features)

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
            # y_i = self.upsample1(y_i)
        # hypercol = torch.cat([y0,y1,y2,y3,y4], dim=1)
        x = self.matting_head(x)
        x = 1 - nn.ReLU()(1 - x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(
            scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = conv2d_relu(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
        )
        self.conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels

    def forward(self, x, skip=None):
        if skip is None:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            if x.shape[-1] != skip.shape[-1]:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            # print(x.shape, skip.shape)
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        conv1 = conv2d_relu(
            in_channels,
            out_channels,
            kernel_size=3,
        )
        conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        super().__init__(conv1, conv2)


class SegForm(nn.Module):
    def __init__(self):
        super(SegForm, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1  ## set output as 1
        # self.model = SegformerForSemanticSegmentation(config=configuration)

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True)

    def forward(self, image):
        img_segs = self.model(image)
        upsampled_logits = nn.functional.interpolate(img_segs.logits,
                                                     scale_factor=4,
                                                     mode='nearest')
        return upsampled_logits


class StyleMatte(nn.Module):
    def __init__(self):
        super(StyleMatte, self).__init__()
        # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        # configuration.num_labels = 1  ## set output as 1
        self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
        self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained(
            "facebook/mask2former-swin-tiny-coco-instance").base_model.pixel_level_module
        self.fgf = FastGuidedFilter()
        self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_mean', self.mean)
        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_std', self.std)

    def forward(self, image, normalize=False):
        # if normalize:
        #     image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))

        decoder_out = self.pixel_decoder(image)
        decoder_states = list(decoder_out.decoder_hidden_states)
        decoder_states.append(decoder_out.decoder_last_hidden_state)
        out_pure = self.fpn(decoder_states)

        image_lr = nn.functional.interpolate(image.mean(1, keepdim=True),
                                             scale_factor=0.25,
                                             mode='bicubic',
                                             align_corners=True)
        out = self.conv(out_pure)
        out = self.fgf(image_lr, out, image.mean(1, keepdim=True))  # .clip(0,1)
        # out = nn.Sigmoid()(out)
        # out = nn.functional.interpolate(out,
        #                                 scale_factor=4,
        #                                 mode='bicubic',
        #                                 align_corners=True)

        return torch.sigmoid(out)

    def get_training_params(self):
        # +list(self.fgf.parameters())
        return list(self.fpn.parameters()) + list(self.conv.parameters())

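For orientation, a minimal smoke-test sketch of the StyleMatte module defined above. The 512x512 input, the random tensor, and the expected output shape are editorial assumptions for illustration; a real run also needs the pretrained Mask2Former backbone to be downloadable and, for meaningful mattes, the trained StyleMatte weights loaded via load_state_dict.

# Editor's sketch, not part of tools/stylematte.py (assumed 512x512 input).
import torch

net = StyleMatte().eval()
dummy = torch.rand(1, 3, 512, 512)   # one RGB image with values in [0, 1]
with torch.no_grad():
    alpha = net(dummy)               # sigmoid output in (0, 1)
print(alpha.shape)                   # expected to match the input resolution, e.g. (1, 1, 512, 512)
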
class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N
        N = self.boxfilter(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0))

        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b


class FastGuidedFilter(nn.Module):
    def __init__(self, r=1, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1

        # N
        N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))

        # mean_x
        mean_x = self.boxfilter(lr_x) / N
        # mean_y
        mean_y = self.boxfilter(lr_y) / N
        # cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = F.interpolate(
            A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(
            b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A * hr_x + mean_b

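FastGuidedFilter above fits the guided-filter coefficients A and b on a low-resolution guide/target pair and then applies the bilinearly upsampled coefficients to the full-resolution guide, which is how StyleMatte.forward refines its coarse one-channel prediction. A hedged usage sketch follows; all shapes are editorial assumptions.

# Editor's sketch, not part of tools/stylematte.py.
import torch
import torch.nn.functional as F

fgf = FastGuidedFilter(r=1, eps=1e-8)
hr_guide = torch.rand(1, 1, 512, 512)                       # full-resolution grayscale guide
lr_guide = F.interpolate(hr_guide, scale_factor=0.25,
                         mode='bicubic', align_corners=True)
lr_target = torch.rand(1, 1, 128, 128)                      # coarse prediction to refine
hr_out = fgf(lr_guide, lr_target, hr_guide)                 # -> (1, 1, 512, 512)
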
class DeepGuidedFilterRefiner(nn.Module):
    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(
            4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels,
                      kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )

    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)

        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x = self.box_filter(base_x * base_x) - mean_x * mean_x

        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x

        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)

        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha


def diff_x(input, r):
    assert input.dim() == 4

    left = input[:, :, r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
    right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output


def diff_y(input, r):
    assert input.dim() == 4

    left = input[:, :, :, r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output


class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4

        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
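A short editorial note on BoxFilter: it builds windowed sums from two cumulative sums plus the diff_x/diff_y differences, so the cost per pixel does not grow with the radius r. Near the borders the window is clipped to the image, which is why GuidedFilter and FastGuidedFilter normalize by N = boxfilter(ones). A small equivalence check against an all-ones convolution with zero padding (sizes are arbitrary assumptions, not values from this repo):

# Editor's sketch, not part of tools/stylematte.py.
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 16, 16)
r = 2
box = BoxFilter(r)(x)                                            # clipped windowed sums
ref = F.conv2d(x, torch.ones(1, 1, 2 * r + 1, 2 * r + 1), padding=r)
print(torch.allclose(box, ref, atol=1e-4))                       # expected: True
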
tools/util.py
ADDED
@@ -0,0 +1,345 @@
import math
import numpy as np
from typing import Tuple
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import cv2
from torchvision import transforms, models


def log(msg, lvl='info'):
    if lvl == 'info':
        print(f"***********{msg}****************")
    if lvl == 'error':
        print(f"!!! Exception: {msg} !!!")


def lab_shift(x, invert=False):
    x = x.float()
    if invert:
        x[:, 0, :, :] /= 2.55
        x[:, 1, :, :] -= 128
        x[:, 2, :, :] -= 128
    else:
        x[:, 0, :, :] *= 2.55
        x[:, 1, :, :] += 128
        x[:, 2, :, :] += 128

    return x


def calculate_psnr(img1, img2):
    # img1 and img2 have range [0, 255]
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')

    return 20 * math.log10(255.0 / math.sqrt(mse))


def calculate_fpsnr(fmse):
    return 10 * math.log10(255.0 / (fmse + 1e-8))


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1), bit=8):
    '''
    Converts a torch Tensor into an image Numpy array.
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    norm = float(2**bit) - 1
    # print('before', tensor[:,:,0].max(), tensor[:,:,0].min(), '\t', tensor[:,:,1].max(), tensor[:,:,1].min(), '\t', tensor[:,:,2].max(), tensor[:,:,2].min())
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    # print('clamp ', tensor[:,:,0].max(), tensor[:,:,0].min(), '\t', tensor[:,:,1].max(), tensor[:,:,1].min(), '\t', tensor[:,:,2].max(), tensor[:,:,2].min())
    tensor = (tensor - min_max[0]) / \
        (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(
            math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
        img_np = (img_np * norm).round()
    return img_np.astype(out_type)

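A hedged usage sketch for the helpers above: tensor2img maps a CHW float tensor in [0, 1] to an 8-bit HWC array with the channels flipped to BGR, which pairs naturally with calculate_psnr. The tensors and the noise level below are placeholders, not values used by this repo.

# Editor's sketch, not part of tools/util.py.
import torch

pred = torch.rand(3, 64, 64)
noisy = (pred + 0.01 * torch.randn_like(pred)).clamp(0, 1)
pred_img = tensor2img(pred)                    # (64, 64, 3), uint8, channels flipped to BGR
noisy_img = tensor2img(noisy)
print(pred_img.shape, pred_img.dtype)
print(calculate_psnr(pred_img, noisy_img))     # roughly 40 dB for 1% noise
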
def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to Lab.

    .. image:: _static/img/rgb_to_lab.png

    The input RGB image is assumed to be in the range of :math:`[0, 1]`. Lab
    color is computed using the D65 illuminant and Observer 2.

    Args:
        image: RGB Image to be converted to Lab with shape :math:`(*, 3, H, W)`.

    Returns:
        Lab version of the image with shape :math:`(*, 3, H, W)`.
        The L channel values are in the range 0..100. a and b are in the range -128..127.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_lab(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Convert from sRGB to Linear RGB
    lin_rgb = rgb_to_linear_rgb(image)

    xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb)

    # normalize for D65 white point
    xyz_ref_white = torch.tensor(
        [0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype)[..., :, None, None]
    xyz_normalized = torch.div(xyz_im, xyz_ref_white)

    threshold = 0.008856
    power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0)
    scale = 7.787 * xyz_normalized + 4.0 / 29.0
    xyz_int = torch.where(xyz_normalized > threshold, power, scale)

    x: torch.Tensor = xyz_int[..., 0, :, :]
    y: torch.Tensor = xyz_int[..., 1, :, :]
    z: torch.Tensor = xyz_int[..., 2, :, :]

    L: torch.Tensor = (116.0 * y) - 16.0
    a: torch.Tensor = 500.0 * (x - y)
    _b: torch.Tensor = 200.0 * (y - z)

    out: torch.Tensor = torch.stack([L, a, _b], dim=-3)

    return out


def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    The L channel is assumed to be in the range of :math:`[0, 100]`.
    a and b channels are in the range of :math:`[-128, 127]`.

    Args:
        image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        clip: Whether to apply clipping to ensure output RGB values are in range :math:`[0, 1]`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.
        The output RGB image is in the range of :math:`[0, 1]`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = lab_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    L: torch.Tensor = image[..., 0, :, :]
    a: torch.Tensor = image[..., 1, :, :]
    _b: torch.Tensor = image[..., 2, :, :]

    fy = (L + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = fy - (_b / 200.0)

    # if color data out of range: Z < 0
    fz = fz.clamp(min=0.0)

    fxyz = torch.stack([fx, fy, fz], dim=-3)

    # Convert from Lab to XYZ
    power = torch.pow(fxyz, 3.0)
    scale = (fxyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(fxyz > 0.2068966, power, scale)

    # For D65 white point
    xyz_ref_white = torch.tensor(
        [0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[..., :, None, None]
    xyz_im = xyz * xyz_ref_white

    rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im)

    # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107
    # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im)

    # Convert from RGB Linear to sRGB
    rgb_im = linear_rgb_to_rgb(rgbs_im)

    # Clip to 0,1 https://www.w3.org/Graphics/Color/srgb
    if clip:
        rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)

    return rgb_im


def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to XYZ.

    .. image:: _static/img/rgb_to_xyz.png

    Args:
        image: RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`.

    Returns:
        XYZ version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_xyz(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    r: torch.Tensor = image[..., 0, :, :]
    g: torch.Tensor = image[..., 1, :, :]
    b: torch.Tensor = image[..., 2, :, :]

    x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b
    y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b
    z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b

    out: torch.Tensor = torch.stack([x, y, z], -3)

    return out


def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an XYZ image to RGB.

    Args:
        image: XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = xyz_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    x: torch.Tensor = image[..., 0, :, :]
    y: torch.Tensor = image[..., 1, :, :]
    z: torch.Tensor = image[..., 2, :, :]

    r: torch.Tensor = 3.2404813432005266 * x + \
        -1.5371515162713185 * y + -0.4985363261688878 * z
    g: torch.Tensor = -0.9692549499965682 * x + \
        1.8759900014898907 * y + 0.0415559265582928 * z
    b: torch.Tensor = 0.0556466391351772 * x + \
        -0.2040413383665112 * y + 1.0573110696453443 * z

    out: torch.Tensor = torch.stack([r, g, b], dim=-3)

    return out


def rgb_to_linear_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an sRGB image to linear RGB. Used in colorspace conversions.

    .. image:: _static/img/rgb_to_linear_rgb.png

    Args:
        image: sRGB Image to be converted to linear RGB of shape :math:`(*, 3, H, W)`.

    Returns:
        Linear RGB version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_linear_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    lin_rgb: torch.Tensor = torch.where(image > 0.04045, torch.pow(
        ((image + 0.055) / 1.055), 2.4), image / 12.92)

    return lin_rgb


def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a linear RGB image to sRGB. Used in colorspace conversions.

    Args:
        image: linear RGB Image to be converted to sRGB of shape :math:`(*, 3, H, W)`.

    Returns:
        sRGB version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = linear_rgb_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    threshold = 0.0031308
    rgb: torch.Tensor = torch.where(
        image > threshold,
        1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055,
        12.92 * image,
    )

    return rgb

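The conversions above are intended to be mutually inverse, so an RGB to Lab and back round trip should reproduce the input up to small numerical error. A quick editorial check (the random input is a placeholder):

# Editor's sketch, not part of tools/util.py.
import torch

rgb = torch.rand(2, 3, 32, 32)
lab = rgb_to_lab(rgb)            # L in [0, 100], a/b roughly in [-128, 127]
back = lab_to_rgb(lab)
print(torch.allclose(rgb, back, atol=1e-3))   # expected: True
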
def inference_img(model, img, device='cpu'):
    # Pads the image with reflection on the top/left so height and width are
    # multiples of 8, runs the model on the ImageNet-normalized tensor, and
    # crops the prediction back to the original size.
    h, w, _ = img.shape
    # print(img.shape)
    if h % 8 != 0 or w % 8 != 0:
        img = cv2.copyMakeBorder(img, 8 - h % 8, 0, 8 - w % 8, 0,
                                 cv2.BORDER_REFLECT)
    # print(img.shape)

    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img
    input_t = input_t / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t)
    input_t = input_t.unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    # print("out", out.shape)
    result = out[0][:, -h:, -w:].cpu().numpy()
    # print(result.shape)

    return result[0]
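Finally, a hedged sketch of how inference_img is typically combined with the StyleMatte network from tools/stylematte.py. The import path, file names, and RGB conversion are editorial assumptions; the function itself only expects an HxWx3 image array and a model that returns a one-channel matte.

# Editor's sketch, not part of tools/util.py.
import cv2
from tools import StyleMatte   # assumed export; adjust to the actual module path

model = StyleMatte().eval()    # load the trained matting checkpoint here for real results
image = cv2.cvtColor(cv2.imread("portrait.jpg"), cv2.COLOR_BGR2RGB)   # hypothetical file
alpha = inference_img(model, image)                # HxW float matte, roughly in [0, 1]
cv2.imwrite("alpha.png", (alpha * 255).astype("uint8"))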