huqiming513
commited on
Commit
•
e538b68
1
Parent(s):
a1a3080
Upload 14 files
Browse files- .gitignore +129 -0
- LICENSE +201 -0
- README.md +60 -3
- eval_Bread.py +218 -0
- exposure_augment.py +60 -0
- scripts.sh +42 -0
- test_Bread.py +195 -0
- test_Bread_NoNFM.py +167 -0
- train_ANSN.py +264 -0
- train_CAN.py +276 -0
- train_IAN.py +255 -0
- train_MECAN.py +261 -0
- train_MECAN_finetune.py +268 -0
- train_NFM.py +279 -0
.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,60 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## [Low-light Image Enhancement via Breaking Down the Darkness](https://arxiv.org/abs/2111.15557)
|
2 |
+
by Xiaojie Guo, Qiming Hu.
|
3 |
+
|
4 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mingcv/Bread/blob/main/bread_demo_uploader.ipynb) (Online Demo)
|
5 |
+
|
6 |
+
<!-- ![figure_tease](https://github.com/mingcv/Bread/blob/main/figures/figure_tease.png) -->
|
7 |
+
|
8 |
+
### 1. Dependencies
|
9 |
+
* Python3
|
10 |
+
* PyTorch>=1.0
|
11 |
+
* OpenCV-Python, TensorboardX
|
12 |
+
* NVIDIA GPU+CUDA
|
13 |
+
|
14 |
+
### 2. Network Architecture
|
15 |
+
![figure_arch](https://github.com/mingcv/Bread/blob/main/figures/Bread_architecture_full.png)
|
16 |
+
|
17 |
+
### 3. Data Preparation
|
18 |
+
|
19 |
+
#### 3.1. Training dataset
|
20 |
+
* 485 low/high-light image pairs from our485 of [LOL dataset](https://daooshee.github.io/BMVC2018website/), each low image of which is augmented by our [exposure_augment.py](https://github.com/mingcv/Bread/blob/main/exposure_augment.py) to generate 8 images under different exposures. ([Download Link for Augmented LOL](https://drive.google.com/file/d/1gyX2kYJWuj3C00eobd49MjRuNbZ29dqN/view?usp=sharing))
|
21 |
+
* To train the MECAN (if it is desired), 559 randomly-selected multi-exposure sequences from [SICE](https://github.com/csjcai/SICE) are adopted ([Download Link for a resized version](https://drive.google.com/file/d/1OTNP-QJ3Nade5my04A2iYVTY77IQBEMf/view?usp=sharing)).
|
22 |
+
|
23 |
+
#### 3.2. Tesing dataset
|
24 |
+
The images for testing can be downloaded in [this link](https://github.com/mingcv/Bread/releases/download/checkpoints/data.zip).
|
25 |
+
|
26 |
+
<!-- * 15 low/high-light image pairs from eval15 of [LOL dataset](https://daooshee.github.io/BMVC2018website/).
|
27 |
+
* 44 low-light images from DICM.
|
28 |
+
* 8 low-light images from NPE.
|
29 |
+
* 24 low-light images from VV. -->
|
30 |
+
|
31 |
+
### 4. Usage
|
32 |
+
|
33 |
+
#### 4.1. Training
|
34 |
+
* Multi-exposure data synthesis: ```python exposure_augment.py```
|
35 |
+
* Train IAN: ```python train_IAN.py -m IAN --comment IAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche```
|
36 |
+
* Train ANSN: ```python train_ANSN.py -m1 IAN -m2 ANSN --comment ANSN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth```
|
37 |
+
* Train CAN: ```python train_CAN.py -m1 IAN -m3 FuseNet --comment CAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth```
|
38 |
+
* Train MECAN on SICE: ```python train_MECAN.py -m FuseNet --comment MECAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche```
|
39 |
+
* Finetune MECAN on SICE and LOL datasets: ```python train_MECAN_finetune.py -m FuseNet --comment MECAN_finetune --batch_size 1 --val_interval 1 --num_epochs 500 --lr 1e-4 --no_sche -mw ./checkpoints/FuseNet_MECAN_for_Finetuning_404.pth```
|
40 |
+
|
41 |
+
#### 4.2. Testing
|
42 |
+
* *\[Tips\]: Using gamma correction for evaluation with parameter --gc; Show extra intermediate outputs with parameter --save_extra*
|
43 |
+
* Evaluation: ```python eval_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth```
|
44 |
+
* Testing: ```python test_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth```
|
45 |
+
* Remove NFM: ```python test_Bread_NoNFM.py -m1 IAN -m2 ANSN -m3 FuseNet --mef -a 0.10 --comment Bread+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth```
|
46 |
+
|
47 |
+
#### 4.3. Trained weights
|
48 |
+
Please refer to [our release](https://github.com/mingcv/Bread/releases/tag/checkpoints).
|
49 |
+
|
50 |
+
### 5. Quantitative comparison on eval15
|
51 |
+
![table_eval](https://github.com/mingcv/Bread/blob/main/figures/table_eval.png)
|
52 |
+
|
53 |
+
### 6. Visual comparison on eval15
|
54 |
+
![figure_eval](https://github.com/mingcv/Bread/blob/main/figures/figure_eval.png)
|
55 |
+
|
56 |
+
### 7. Visual comparison on DICM
|
57 |
+
![figure_test_dicm](https://github.com/mingcv/Bread/blob/main/figures/figure_test_dicm.png)
|
58 |
+
|
59 |
+
### 8. Visual comparison on VV and MEF-DS
|
60 |
+
![figure_test_vv_mefds](https://github.com/mingcv/Bread/blob/main/figures/figure_test_vv_mefds.png)
|
eval_Bread.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import kornia
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import tqdm
|
8 |
+
from torch import nn
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
import models
|
12 |
+
from datasets import LowLightDataset
|
13 |
+
from tools import saver, mutils
|
14 |
+
from models import PSNR, SSIM
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
|
18 |
+
def get_args():
|
19 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
20 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
21 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
22 |
+
parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
|
23 |
+
parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 Name')
|
24 |
+
parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 Name')
|
25 |
+
parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
|
26 |
+
parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 Name')
|
27 |
+
|
28 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
|
29 |
+
parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
|
30 |
+
parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
|
31 |
+
parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='Model weight of NFM')
|
32 |
+
|
33 |
+
parser.add_argument('--mef', action='store_true', help='using color adation based MEF data or not')
|
34 |
+
parser.add_argument('--gc', action='store_true', help='using gamma correction or not')
|
35 |
+
parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
|
36 |
+
|
37 |
+
parser.add_argument('--comment', type=str, default='default',
|
38 |
+
help='Project comment')
|
39 |
+
|
40 |
+
parser.add_argument('--alpha', '-a', type=float, default=0.10)
|
41 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
42 |
+
parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
|
43 |
+
'suggest using \'admaw\' until the'
|
44 |
+
' very final stage then switch to \'sgd\'')
|
45 |
+
parser.add_argument('--data_path', type=str, default='./data/LOL/eval',
|
46 |
+
help='the root folder of dataset')
|
47 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
48 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
49 |
+
args = parser.parse_args()
|
50 |
+
return args
|
51 |
+
|
52 |
+
|
53 |
+
class ModelBreadNet(nn.Module):
|
54 |
+
def __init__(self, model1, model2, model3, model4):
|
55 |
+
super().__init__()
|
56 |
+
self.eps = 1e-6
|
57 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
58 |
+
self.model_nsnet = model2(in_channels=2, out_channels=1)
|
59 |
+
self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
|
60 |
+
self.model_fdnet = model4(in_channels=3, out_channels=1) if opt.model4 else None
|
61 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
62 |
+
self.load_weight(self.model_nsnet, opt.model2_weight)
|
63 |
+
self.load_weight(self.model_canet, opt.model3_weight)
|
64 |
+
self.load_weight(self.model_fdnet, opt.model4_weight)
|
65 |
+
|
66 |
+
def load_weight(self, model, weight_pth):
|
67 |
+
if model is not None:
|
68 |
+
state_dict = torch.load(weight_pth)
|
69 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
70 |
+
print(ret)
|
71 |
+
|
72 |
+
def noise_syn_exp(self, illumi, strength):
|
73 |
+
return torch.exp(-illumi) * strength
|
74 |
+
|
75 |
+
def forward(self, image, image_gt):
|
76 |
+
# Color space mapping
|
77 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
78 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
79 |
+
|
80 |
+
# Illumination prediction
|
81 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
82 |
+
texture_illumi = self.model_ianet(texture_in_down)
|
83 |
+
texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
84 |
+
|
85 |
+
# Illumination adjustment
|
86 |
+
texture_illumi = torch.clamp(texture_illumi, 0., 1.)
|
87 |
+
texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
|
88 |
+
texture_ia = torch.clamp(texture_ia, 0., 1.)
|
89 |
+
|
90 |
+
# Noise suppression and fusion
|
91 |
+
texture_nss = []
|
92 |
+
for strength in [0., 0.05, 0.1]:
|
93 |
+
attention = self.noise_syn_exp(texture_illumi, strength=strength)
|
94 |
+
texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
|
95 |
+
texture_ns = texture_ia + texture_res
|
96 |
+
texture_nss.append(texture_ns)
|
97 |
+
texture_nss = torch.cat(texture_nss, dim=1).detach()
|
98 |
+
texture_fd = self.model_fdnet(texture_nss)
|
99 |
+
|
100 |
+
# Gamma correction to align the brightness with ground truth;
|
101 |
+
# other methods involved in our main paper are also conducted the same correction for evaluation.
|
102 |
+
if opt.gc:
|
103 |
+
max_psnr = 0
|
104 |
+
best = None
|
105 |
+
for ga in np.arange(0.1, 2.0, 0.01):
|
106 |
+
tx_en = texture_fd ** ga
|
107 |
+
psnr = PSNR(tx_en, texture_gt)
|
108 |
+
if psnr > max_psnr:
|
109 |
+
max_psnr = psnr
|
110 |
+
best = tx_en
|
111 |
+
|
112 |
+
texture_fd = torch.clamp(best, 0, 1)
|
113 |
+
|
114 |
+
# Color adaption
|
115 |
+
if not opt.mef:
|
116 |
+
image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
|
117 |
+
_, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
|
118 |
+
colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
|
119 |
+
else:
|
120 |
+
colors = self.model_canet(
|
121 |
+
torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))
|
122 |
+
|
123 |
+
cb_out, cr_out = torch.split(colors, 1, dim=1)
|
124 |
+
cb_out = torch.clamp(cb_out, 0, 1)
|
125 |
+
cr_out = torch.clamp(cr_out, 0, 1)
|
126 |
+
|
127 |
+
# Color space mapping
|
128 |
+
image_out = kornia.color.ycbcr_to_rgb(
|
129 |
+
torch.cat([texture_fd, cb_out, cr_out], dim=1))
|
130 |
+
image_out = torch.clamp(image_out, 0, 1)
|
131 |
+
|
132 |
+
# Calculating image quality metrics
|
133 |
+
psnr = PSNR(image_out, image_gt)
|
134 |
+
ssim = SSIM(image_out, image_gt).item()
|
135 |
+
|
136 |
+
return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res, psnr, ssim
|
137 |
+
|
138 |
+
|
139 |
+
def evaluation(opt):
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
torch.cuda.manual_seed(42)
|
142 |
+
else:
|
143 |
+
torch.manual_seed(42)
|
144 |
+
|
145 |
+
timestamp = mutils.get_formatted_time()
|
146 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
147 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
148 |
+
|
149 |
+
val_params = {'batch_size': 1,
|
150 |
+
'shuffle': False,
|
151 |
+
'drop_last': False,
|
152 |
+
'num_workers': opt.num_workers}
|
153 |
+
|
154 |
+
val_set = LowLightDataset(opt.data_path)
|
155 |
+
|
156 |
+
val_generator = DataLoader(val_set, **val_params)
|
157 |
+
val_generator = tqdm.tqdm(val_generator)
|
158 |
+
|
159 |
+
model1 = getattr(models, opt.model1)
|
160 |
+
model2 = getattr(models, opt.model2)
|
161 |
+
model3 = getattr(models, opt.model3)
|
162 |
+
model4 = getattr(models, opt.model4) if opt.model4 else None
|
163 |
+
|
164 |
+
model = ModelBreadNet(model1, model2, model3, model4)
|
165 |
+
print(model)
|
166 |
+
|
167 |
+
if opt.num_gpus > 0:
|
168 |
+
model = model.cuda()
|
169 |
+
if opt.num_gpus > 1:
|
170 |
+
model = nn.DataParallel(model)
|
171 |
+
|
172 |
+
model.eval()
|
173 |
+
psnrs, ssims, fns = [], [], []
|
174 |
+
for iter, (data, target, name) in enumerate(val_generator):
|
175 |
+
saver.base_url = os.path.join(opt.saved_path, 'results')
|
176 |
+
with torch.no_grad():
|
177 |
+
if opt.num_gpus == 1:
|
178 |
+
data = data.cuda()
|
179 |
+
target = target.cuda()
|
180 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
|
181 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
|
182 |
+
texture_ia, texture_nss, texture_fd, image_out, \
|
183 |
+
texture_illumi, texture_res, psnr, ssim = model(data, target)
|
184 |
+
if opt.save_extra:
|
185 |
+
saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
|
186 |
+
saver.save_image(target, name=os.path.splitext(name[0])[0] + '_im_gt')
|
187 |
+
|
188 |
+
saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
|
189 |
+
saver.save_image(texture_gt, name=os.path.splitext(name[0])[0] + '_y_gt')
|
190 |
+
|
191 |
+
saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
|
192 |
+
for i in range(texture_nss.shape[1]):
|
193 |
+
saver.save_image(texture_nss[:, i, ...], name=os.path.splitext(name[0])[0] + f'_ns_{i}')
|
194 |
+
saver.save_image(texture_fd, name=os.path.splitext(name[0])[0] + '_fd')
|
195 |
+
|
196 |
+
saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
|
197 |
+
saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
|
198 |
+
|
199 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
|
200 |
+
else:
|
201 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
|
202 |
+
|
203 |
+
psnrs.append(psnr)
|
204 |
+
ssims.append(ssim)
|
205 |
+
fns.append(name[0])
|
206 |
+
|
207 |
+
results = list(zip(psnrs, ssims, fns))
|
208 |
+
results.sort(key=lambda item: item[0])
|
209 |
+
for r in results:
|
210 |
+
print(*r)
|
211 |
+
psnr = np.mean(np.array(psnrs))
|
212 |
+
ssim = np.mean(np.array(ssims))
|
213 |
+
print('psnr: ', psnr, ', ssim: ', ssim)
|
214 |
+
|
215 |
+
|
216 |
+
if __name__ == '__main__':
|
217 |
+
opt = get_args()
|
218 |
+
evaluation(opt)
|
exposure_augment.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
|
4 |
+
import PIL.Image as Image
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms as vtrans
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def main(fip, fod):
|
12 |
+
max_overex_rate = 0.25
|
13 |
+
steps = 20
|
14 |
+
num_gen = 4
|
15 |
+
|
16 |
+
im = Image.open(fip)
|
17 |
+
im = vtrans.ToTensor()(im)
|
18 |
+
im_max = torch.flatten(torch.max(im, dim=0, keepdim=True).values)
|
19 |
+
mag = 1. / torch.topk(im_max, math.floor(len(im_max) * max_overex_rate + 1)).values
|
20 |
+
mag = mag[range(0, len(mag), int(len(mag) * (1. / steps)))]
|
21 |
+
mag_diff = torch.diff(mag, 1)
|
22 |
+
mag = mag[:-1]
|
23 |
+
|
24 |
+
top_mag_diff = torch.topk(mag_diff, num_gen).values
|
25 |
+
min_gain = top_mag_diff[top_mag_diff > 0][-1]
|
26 |
+
min_mag = mag[0]
|
27 |
+
max_mag = mag[mag_diff > min_gain][-1]
|
28 |
+
fn, ext = os.path.basename(fip).split('.')
|
29 |
+
bar.set_description(f'{fn}: {min_gain}')
|
30 |
+
ma = np.arange(1, min_mag - min_gain, min_gain * 2)
|
31 |
+
if len(ma) > num_gen:
|
32 |
+
mags = np.append(np.linspace(1, min_mag - min_gain, num_gen),
|
33 |
+
np.linspace(min_mag, max_mag, num_gen))
|
34 |
+
elif len(ma) == num_gen:
|
35 |
+
mags = np.append(ma, np.linspace(min_mag, max_mag, num_gen))
|
36 |
+
else:
|
37 |
+
mags = np.linspace(1, max_mag, num_gen * 2)
|
38 |
+
|
39 |
+
im = Image.open(fip)
|
40 |
+
im_raw = vtrans.ToTensor()(im)
|
41 |
+
|
42 |
+
for i, mag in enumerate(mags):
|
43 |
+
im = im_raw * mag
|
44 |
+
im.clamp_max_(1.)
|
45 |
+
fop = os.path.join(fod, f'{fn}_{i}.{ext}')
|
46 |
+
|
47 |
+
if not os.path.exists(fop):
|
48 |
+
vtrans.ToPILImage()(im).save(fop)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
# one needs to download it online
|
53 |
+
fid = './data/LOL/train/images'
|
54 |
+
fod = './data/LOL/train/images_aug'
|
55 |
+
os.makedirs(fod, exist_ok=True)
|
56 |
+
|
57 |
+
bar = tqdm.tqdm(os.listdir(fid))
|
58 |
+
for fn in bar:
|
59 |
+
fip = os.path.join(fid, fn)
|
60 |
+
main(fip, fod)
|
scripts.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#######################################
|
2 |
+
# Breaking Down the Darkness : Testing
|
3 |
+
#######################################
|
4 |
+
|
5 |
+
# Using gamma correction for evaluation with parameter --gc
|
6 |
+
# Show extra intermediate outputs with parameter --save_extra
|
7 |
+
CUDA_VISIBLE_DEVICES=0 python eval_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
|
8 |
+
CUDA_VISIBLE_DEVICES=0 python test_Bread.py -m1 IAN -m2 ANSN -m3 FuseNet -m4 FuseNet --mef --comment Bread+NFM+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
|
9 |
+
|
10 |
+
# Using CAN w/o MEF data
|
11 |
+
CUDA_VISIBLE_DEVICES=0 python eval_Bread.py -m1 IAN -m2 ANSN -m3 IAN -m4 FuseNet --comment Bread+NFM[eval] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/IANet_CAN_51.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
|
12 |
+
CUDA_VISIBLE_DEVICES=0 python test_Bread.py -m1 IAN -m2 ANSN -m3 IAN -m4 FuseNet --comment Bread+NFM[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/IANet_CAN_51.pth -m4w ./checkpoints/FuseNet_NFM_297.pth
|
13 |
+
|
14 |
+
# Remove NFM for much better generalization performance
|
15 |
+
# changing parameter a for an ideal denoising strength
|
16 |
+
CUDA_VISIBLE_DEVICES=0 python test_Bread_NoNFM.py -m1 IAN -m2 ANSN -m3 FuseNet --mef -a 0.10 --comment Bread+ME[test] --batch_size 1 -m1w ./checkpoints/IAN_335.pth -m2w ./checkpoints/ANSN_422.pth -m3w ./checkpoints/FuseNet_MECAN_251.pth
|
17 |
+
|
18 |
+
##############################################################
|
19 |
+
# Breaking Down the Darkness : Training
|
20 |
+
# SICE dataset and LOL dataset are need to be download online
|
21 |
+
##############################################################
|
22 |
+
|
23 |
+
# Multi-exposure data synthesis
|
24 |
+
python exposure_augment.py
|
25 |
+
|
26 |
+
# Train IAN
|
27 |
+
CUDA_VISIBLE_DEVICES=0 python train_IAN.py -m IAN --comment IAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche
|
28 |
+
|
29 |
+
# Train ANSN
|
30 |
+
CUDA_VISIBLE_DEVICES=0 python train_ANSN.py -m1 IAN -m2 ANSN --comment ANSN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth
|
31 |
+
|
32 |
+
# Train CAN
|
33 |
+
CUDA_VISIBLE_DEVICES=0 python train_CAN.py -m1 IAN -m3 FuseNet --comment CAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche -m1w ./checkpoints/IAN_335.pth
|
34 |
+
|
35 |
+
# Train MECAN on SICE
|
36 |
+
CUDA_VISIBLE_DEVICES=0 python train_MECAN.py -m FuseNet --comment MECAN_train --batch_size 1 --val_interval 1 --num_epochs 500 --lr 0.001 --no_sche
|
37 |
+
|
38 |
+
# Finetune MECAN on SICE and LOL datasets
|
39 |
+
CUDA_VISIBLE_DEVICES=0 python train_MECAN_finetune.py -m FuseNet --comment MECAN_finetune --batch_size 1 --val_interval 1 --num_epochs 500 --lr 1e-4 --no_sche -mw ./checkpoints/FuseNet_MECAN_for_Finetuning_404.pth
|
40 |
+
|
41 |
+
|
42 |
+
|
test_Bread.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import kornia
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import tqdm
|
8 |
+
from torch import nn
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
import models
|
12 |
+
from datasets import LowLightDatasetTest
|
13 |
+
from tools import saver, mutils
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
18 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
19 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
20 |
+
parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
|
21 |
+
parser.add_argument('-m1', '--model1', type=str, default='IANet', help='Model1 Name')
|
22 |
+
parser.add_argument('-m2', '--model2', type=str, default='NSNet', help='Model2 Name')
|
23 |
+
parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
|
24 |
+
parser.add_argument('-m4', '--model4', type=str, default=None, help='Model4 Name')
|
25 |
+
|
26 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
|
27 |
+
parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
|
28 |
+
parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
|
29 |
+
parser.add_argument('-m4w', '--model4_weight', type=str, default=None, help='Model weight of NFM')
|
30 |
+
|
31 |
+
parser.add_argument('--mef', action='store_true')
|
32 |
+
parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
|
33 |
+
|
34 |
+
parser.add_argument('--comment', type=str, default='default',
|
35 |
+
help='Project comment')
|
36 |
+
|
37 |
+
parser.add_argument('--alpha', '-a', type=float, default=0.10)
|
38 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
39 |
+
parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
|
40 |
+
'suggest using \'admaw\' until the'
|
41 |
+
' very final stage then switch to \'sgd\'')
|
42 |
+
parser.add_argument('--data_path', type=str, default='./data/test',
|
43 |
+
help='the root folder of dataset')
|
44 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
45 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
46 |
+
args = parser.parse_args()
|
47 |
+
return args
|
48 |
+
|
49 |
+
|
50 |
+
class ModelBreadNet(nn.Module):
|
51 |
+
def __init__(self, model1, model2, model3, model4):
|
52 |
+
super().__init__()
|
53 |
+
self.eps = 1e-6
|
54 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
55 |
+
self.model_nsnet = model2(in_channels=2, out_channels=1)
|
56 |
+
self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
|
57 |
+
self.model_fdnet = model4(in_channels=3, out_channels=1) if opt.model4 else None
|
58 |
+
|
59 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
60 |
+
self.load_weight(self.model_nsnet, opt.model2_weight)
|
61 |
+
self.load_weight(self.model_canet, opt.model3_weight)
|
62 |
+
self.load_weight(self.model_fdnet, opt.model4_weight)
|
63 |
+
|
64 |
+
def load_weight(self, model, weight_pth):
|
65 |
+
if model is not None:
|
66 |
+
state_dict = torch.load(weight_pth)
|
67 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
68 |
+
print(ret)
|
69 |
+
|
70 |
+
def noise_syn_exp(self, illumi, strength):
|
71 |
+
return torch.exp(-illumi) * strength
|
72 |
+
|
73 |
+
def forward(self, image):
|
74 |
+
# Color space mapping
|
75 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
76 |
+
|
77 |
+
# Illumination prediction
|
78 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
79 |
+
texture_illumi = self.model_ianet(texture_in_down)
|
80 |
+
texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
81 |
+
|
82 |
+
# Illumination adjustment
|
83 |
+
texture_illumi = torch.clamp(texture_illumi, 0., 1.)
|
84 |
+
texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
|
85 |
+
texture_ia = torch.clamp(texture_ia, 0., 1.)
|
86 |
+
|
87 |
+
# Noise suppression and fusion
|
88 |
+
texture_nss = []
|
89 |
+
for strength in [0., 0.05, 0.1]:
|
90 |
+
attention = self.noise_syn_exp(texture_illumi, strength=strength)
|
91 |
+
texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
|
92 |
+
texture_ns = texture_ia + texture_res
|
93 |
+
texture_nss.append(texture_ns)
|
94 |
+
texture_nss = torch.cat(texture_nss, dim=1).detach()
|
95 |
+
texture_fd = self.model_fdnet(texture_nss)
|
96 |
+
|
97 |
+
# Further preserve the texture under brighter illumination
|
98 |
+
texture_fd = texture_illumi * texture_in + (1 - texture_illumi) * texture_fd
|
99 |
+
texture_fd = torch.clamp(texture_fd, 0, 1)
|
100 |
+
|
101 |
+
# Color adaption
|
102 |
+
if not opt.mef:
|
103 |
+
image_ia_ycbcr = kornia.color.rgb_to_ycbcr(torch.clamp(image / (texture_illumi + self.eps), 0, 1))
|
104 |
+
_, cb_ia, cr_ia = torch.split(image_ia_ycbcr, 1, dim=1)
|
105 |
+
colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_fd, cb_ia, cr_ia], dim=1))
|
106 |
+
else:
|
107 |
+
colors = self.model_canet(
|
108 |
+
torch.cat([texture_in, cb_in, cr_in, texture_fd], dim=1))
|
109 |
+
cb_out, cr_out = torch.split(colors, 1, dim=1)
|
110 |
+
cb_out = torch.clamp(cb_out, 0, 1)
|
111 |
+
cr_out = torch.clamp(cr_out, 0, 1)
|
112 |
+
|
113 |
+
# Color space mapping
|
114 |
+
image_out = kornia.color.ycbcr_to_rgb(
|
115 |
+
torch.cat([texture_fd, cb_out, cr_out], dim=1))
|
116 |
+
|
117 |
+
# Further preserve the color under brighter illumination
|
118 |
+
img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
|
119 |
+
_, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
|
120 |
+
image_out = kornia.color.ycbcr_to_rgb(
|
121 |
+
torch.cat([texture_fd, cb_fuse, cr_fuse], dim=1))
|
122 |
+
image_out = torch.clamp(image_out, 0, 1)
|
123 |
+
|
124 |
+
return texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res
|
125 |
+
|
126 |
+
|
127 |
+
def test(opt):
|
128 |
+
if torch.cuda.is_available():
|
129 |
+
torch.cuda.manual_seed(42)
|
130 |
+
else:
|
131 |
+
torch.manual_seed(42)
|
132 |
+
|
133 |
+
timestamp = mutils.get_formatted_time()
|
134 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
135 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
136 |
+
|
137 |
+
test_params = {'batch_size': 1,
|
138 |
+
'shuffle': False,
|
139 |
+
'drop_last': False,
|
140 |
+
'num_workers': opt.num_workers}
|
141 |
+
|
142 |
+
test_set = LowLightDatasetTest(opt.data_path)
|
143 |
+
|
144 |
+
test_generator = DataLoader(test_set, **test_params)
|
145 |
+
test_generator = tqdm.tqdm(test_generator)
|
146 |
+
|
147 |
+
model1 = getattr(models, opt.model1)
|
148 |
+
model2 = getattr(models, opt.model2)
|
149 |
+
model3 = getattr(models, opt.model3)
|
150 |
+
model4 = getattr(models, opt.model4)
|
151 |
+
|
152 |
+
model = ModelBreadNet(model1, model2, model3, model4)
|
153 |
+
print(model)
|
154 |
+
|
155 |
+
if opt.num_gpus > 0:
|
156 |
+
model = model.cuda()
|
157 |
+
if opt.num_gpus > 1:
|
158 |
+
model = nn.DataParallel(model)
|
159 |
+
|
160 |
+
model.eval()
|
161 |
+
|
162 |
+
for iter, (data, subset, name) in enumerate(test_generator):
|
163 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', subset[0])
|
164 |
+
with torch.no_grad():
|
165 |
+
if opt.num_gpus == 1:
|
166 |
+
data = data.cuda()
|
167 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
|
168 |
+
|
169 |
+
texture_ia, texture_nss, texture_fd, image_out, texture_illumi, texture_res = model(data)
|
170 |
+
|
171 |
+
if opt.save_extra:
|
172 |
+
saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
|
173 |
+
saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
|
174 |
+
saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
|
175 |
+
for i in range(texture_nss.shape[1]):
|
176 |
+
saver.save_image(texture_nss[:, i, ...], name=os.path.splitext(name[0])[0] + f'_ns_{i}')
|
177 |
+
saver.save_image(texture_fd, name=os.path.splitext(name[0])[0] + '_fd')
|
178 |
+
|
179 |
+
saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
|
180 |
+
saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
|
181 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
|
182 |
+
else:
|
183 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
|
184 |
+
|
185 |
+
|
186 |
+
def save_checkpoint(model, name):
|
187 |
+
if isinstance(model, nn.DataParallel):
|
188 |
+
torch.save(model.module3.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
|
189 |
+
else:
|
190 |
+
torch.save(model.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == '__main__':
|
194 |
+
opt = get_args()
|
195 |
+
test(opt)
|
test_Bread_NoNFM.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import kornia
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import tqdm
|
8 |
+
from torch import nn
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
|
11 |
+
import models
|
12 |
+
from datasets import LowLightDatasetTest
|
13 |
+
from tools import saver, mutils
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
18 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
19 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
20 |
+
parser.add_argument('--batch_size', type=int, default=4, help='The number of images per batch among all devices')
|
21 |
+
parser.add_argument('-m1', '--model1', type=str, default='IAN', help='Model1 Name')
|
22 |
+
parser.add_argument('-m2', '--model2', type=str, default='ANSN', help='Model2 Name')
|
23 |
+
parser.add_argument('-m3', '--model3', type=str, default='FuseNet', help='Model3 Name')
|
24 |
+
|
25 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None, help='Model weight of IAN')
|
26 |
+
parser.add_argument('-m2w', '--model2_weight', type=str, default=None, help='Model weight of ANSN')
|
27 |
+
parser.add_argument('-m3w', '--model3_weight', type=str, default=None, help='Model weight of CAN')
|
28 |
+
|
29 |
+
parser.add_argument('--mef', action='store_true')
|
30 |
+
parser.add_argument('--save_extra', action='store_true', help='save intermediate outputs or not')
|
31 |
+
|
32 |
+
parser.add_argument('--comment', type=str, default='default',
|
33 |
+
help='Project comment')
|
34 |
+
|
35 |
+
parser.add_argument('--alpha', '-a', type=float, default=0.10)
|
36 |
+
|
37 |
+
parser.add_argument('--data_path', type=str, default='./data/test',
|
38 |
+
help='the root folder of dataset')
|
39 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
40 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
41 |
+
args = parser.parse_args()
|
42 |
+
return args
|
43 |
+
|
44 |
+
|
45 |
+
class ModelBreadNet(nn.Module):
|
46 |
+
def __init__(self, model1, model2, model3):
|
47 |
+
super().__init__()
|
48 |
+
self.eps = 1e-6
|
49 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
50 |
+
self.model_nsnet = model2(in_channels=2, out_channels=1)
|
51 |
+
self.model_canet = model3(in_channels=4, out_channels=2) if opt.mef else model3(in_channels=6, out_channels=2)
|
52 |
+
|
53 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
54 |
+
self.load_weight(self.model_nsnet, opt.model2_weight)
|
55 |
+
self.load_weight(self.model_canet, opt.model3_weight)
|
56 |
+
|
57 |
+
def load_weight(self, model, weight_pth):
|
58 |
+
if model is not None:
|
59 |
+
state_dict = torch.load(weight_pth)
|
60 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
61 |
+
print(ret)
|
62 |
+
|
63 |
+
def noise_syn_exp(self, illumi, strength):
|
64 |
+
return torch.exp(-illumi) * strength
|
65 |
+
|
66 |
+
def forward(self, image):
|
67 |
+
# Color space mapping
|
68 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
69 |
+
|
70 |
+
# Illumination prediction
|
71 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
72 |
+
texture_illumi = self.model_ianet(texture_in_down)
|
73 |
+
texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
74 |
+
|
75 |
+
# Illumination adjustment
|
76 |
+
texture_illumi = torch.clamp(texture_illumi, 0., 1.)
|
77 |
+
texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
|
78 |
+
texture_ia = torch.clamp(texture_ia, 0., 1.)
|
79 |
+
|
80 |
+
# Noise suppression and fusion
|
81 |
+
attention = self.noise_syn_exp(texture_illumi, strength=opt.alpha)
|
82 |
+
texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
|
83 |
+
texture_ns = texture_ia + texture_res
|
84 |
+
|
85 |
+
# Further preserve the texture under brighter illumination
|
86 |
+
texture_ns = texture_illumi * texture_in + (1 - texture_illumi) * texture_ns
|
87 |
+
texture_ns = torch.clamp(texture_ns, 0, 1)
|
88 |
+
|
89 |
+
# Color adaption
|
90 |
+
colors = self.model_canet(
|
91 |
+
torch.cat([texture_in, cb_in, cr_in, texture_ns], dim=1))
|
92 |
+
cb_out, cr_out = torch.split(colors, 1, dim=1)
|
93 |
+
cb_out = torch.clamp(cb_out, 0, 1)
|
94 |
+
cr_out = torch.clamp(cr_out, 0, 1)
|
95 |
+
|
96 |
+
# Color space mapping
|
97 |
+
image_out = kornia.color.ycbcr_to_rgb(
|
98 |
+
torch.cat([texture_ns, cb_out, cr_out], dim=1))
|
99 |
+
|
100 |
+
# Further preserve the color under brighter illumination
|
101 |
+
img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
|
102 |
+
_, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
|
103 |
+
image_out = kornia.color.ycbcr_to_rgb(
|
104 |
+
torch.cat([texture_ns, cb_fuse, cr_fuse], dim=1))
|
105 |
+
image_out = torch.clamp(image_out, 0, 1)
|
106 |
+
|
107 |
+
return texture_ia, texture_ns, image_out, texture_illumi, texture_res
|
108 |
+
|
109 |
+
|
110 |
+
def test(opt):
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
torch.cuda.manual_seed(42)
|
113 |
+
else:
|
114 |
+
torch.manual_seed(42)
|
115 |
+
|
116 |
+
timestamp = mutils.get_formatted_time()
|
117 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
118 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
119 |
+
|
120 |
+
test_params = {'batch_size': 1,
|
121 |
+
'shuffle': False,
|
122 |
+
'drop_last': False,
|
123 |
+
'num_workers': opt.num_workers}
|
124 |
+
|
125 |
+
test_set = LowLightDatasetTest(opt.data_path)
|
126 |
+
|
127 |
+
test_generator = DataLoader(test_set, **test_params)
|
128 |
+
test_generator = tqdm.tqdm(test_generator)
|
129 |
+
|
130 |
+
model1 = getattr(models, opt.model1)
|
131 |
+
model2 = getattr(models, opt.model2)
|
132 |
+
model3 = getattr(models, opt.model3)
|
133 |
+
|
134 |
+
model = ModelBreadNet(model1, model2, model3)
|
135 |
+
print(model)
|
136 |
+
|
137 |
+
if opt.num_gpus > 0:
|
138 |
+
model = model.cuda()
|
139 |
+
if opt.num_gpus > 1:
|
140 |
+
model = nn.DataParallel(model)
|
141 |
+
|
142 |
+
model.eval()
|
143 |
+
|
144 |
+
for iter, (data, subset, name) in enumerate(test_generator):
|
145 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', subset[0])
|
146 |
+
with torch.no_grad():
|
147 |
+
if opt.num_gpus == 1:
|
148 |
+
data = data.cuda()
|
149 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
|
150 |
+
|
151 |
+
texture_ia, texture_ns, image_out, texture_illumi, texture_res = model(data)
|
152 |
+
|
153 |
+
if opt.save_extra:
|
154 |
+
saver.save_image(data, name=os.path.splitext(name[0])[0] + '_im_in')
|
155 |
+
saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_y_in')
|
156 |
+
saver.save_image(texture_ia, name=os.path.splitext(name[0])[0] + '_ia')
|
157 |
+
saver.save_image(texture_ns, name=os.path.splitext(name[0])[0] + '_ns')
|
158 |
+
|
159 |
+
saver.save_image(texture_illumi, name=os.path.splitext(name[0])[0] + '_illumi')
|
160 |
+
saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
|
161 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
|
162 |
+
else:
|
163 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_Bread')
|
164 |
+
|
165 |
+
if __name__ == '__main__':
|
166 |
+
opt = get_args()
|
167 |
+
test(opt)
|
train_ANSN.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import kornia
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from tqdm.autonotebook import tqdm
|
13 |
+
|
14 |
+
import models
|
15 |
+
from datasets import LowLightFDataset, LowLightFDatasetEval
|
16 |
+
from models import PSNR, SSIM, CosineLR
|
17 |
+
from tools import SingleSummaryWriter
|
18 |
+
from tools import saver, mutils
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
23 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
24 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
25 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
26 |
+
parser.add_argument('-m1', '--model1', type=str, default='INet',
|
27 |
+
help='Model1 Name')
|
28 |
+
parser.add_argument('-m2', '--model2', type=str, default='NSNet',
|
29 |
+
help='Model1 Name')
|
30 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
|
31 |
+
help='Model Name')
|
32 |
+
|
33 |
+
parser.add_argument('--comment', type=str, default='default',
|
34 |
+
help='Project comment')
|
35 |
+
parser.add_argument('--graph', action='store_true')
|
36 |
+
parser.add_argument('--no_sche', action='store_true')
|
37 |
+
parser.add_argument('--sampling', action='store_true')
|
38 |
+
|
39 |
+
parser.add_argument('--slope', type=float, default=2.)
|
40 |
+
parser.add_argument('--lr', type=float, default=0.001)
|
41 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
42 |
+
'suggest using \'admaw\' until the'
|
43 |
+
' very final stage then switch to \'sgd\'')
|
44 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
45 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
46 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
47 |
+
parser.add_argument('--data_path', type=str, default='./data/LOL',
|
48 |
+
help='the root folder of dataset')
|
49 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
50 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
51 |
+
args = parser.parse_args()
|
52 |
+
return args
|
53 |
+
|
54 |
+
|
55 |
+
class ModelNSNet(nn.Module):
|
56 |
+
def __init__(self, model1, model2):
|
57 |
+
super().__init__()
|
58 |
+
self.texture_loss = models.MSELoss()
|
59 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
60 |
+
self.model_nsnet = model2(in_channels=2, out_channels=1)
|
61 |
+
|
62 |
+
assert opt.model1_weight is not None
|
63 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
64 |
+
self.model_ianet.eval()
|
65 |
+
self.eps = 1e-2
|
66 |
+
|
67 |
+
def load_weight(self, model, weight_pth):
|
68 |
+
state_dict = torch.load(weight_pth)
|
69 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
70 |
+
print(ret)
|
71 |
+
|
72 |
+
def noise_syn(self, illumi, strength):
|
73 |
+
return torch.exp(-illumi) * strength
|
74 |
+
|
75 |
+
def forward(self, image, image_gt, training=True):
|
76 |
+
with torch.no_grad():
|
77 |
+
image = image.squeeze(0)
|
78 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
79 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
80 |
+
|
81 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
82 |
+
illumi = self.model_ianet(texture_in_down)
|
83 |
+
illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
84 |
+
|
85 |
+
attention = self.noise_syn(illumi, strength=0.1)
|
86 |
+
|
87 |
+
noise = torch.normal(mean=0., std=attention)
|
88 |
+
noisy_gt = torch.clamp(texture_gt + noise, 0., 1.)
|
89 |
+
|
90 |
+
texture_res = self.model_nsnet(torch.cat([noisy_gt, attention], dim=1))
|
91 |
+
restor_loss = self.texture_loss(texture_res, texture_gt - noisy_gt)
|
92 |
+
|
93 |
+
texture_ns = noisy_gt + texture_res
|
94 |
+
|
95 |
+
psnr = PSNR(texture_ns, texture_gt)
|
96 |
+
ssim = SSIM(texture_ns, texture_gt).item()
|
97 |
+
return noisy_gt, texture_ns, texture_res, illumi, restor_loss, psnr, ssim
|
98 |
+
|
99 |
+
|
100 |
+
def train(opt):
|
101 |
+
if torch.cuda.is_available():
|
102 |
+
torch.cuda.manual_seed(42)
|
103 |
+
else:
|
104 |
+
torch.manual_seed(42)
|
105 |
+
|
106 |
+
# params.project_name = params.project_name + str(time.time()).replace('.', '')
|
107 |
+
timestamp = mutils.get_formatted_time()
|
108 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
109 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
110 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
111 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
112 |
+
|
113 |
+
training_params = {'batch_size': opt.batch_size,
|
114 |
+
'shuffle': True,
|
115 |
+
'drop_last': True,
|
116 |
+
'num_workers': opt.num_workers}
|
117 |
+
|
118 |
+
val_params = {'batch_size': 1,
|
119 |
+
'shuffle': False,
|
120 |
+
'drop_last': True,
|
121 |
+
'num_workers': opt.num_workers}
|
122 |
+
|
123 |
+
training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'))
|
124 |
+
training_generator = DataLoader(training_set, **training_params)
|
125 |
+
|
126 |
+
val_set = LowLightFDatasetEval(os.path.join(opt.data_path, 'eval'))
|
127 |
+
val_generator = DataLoader(val_set, **val_params)
|
128 |
+
|
129 |
+
model1 = getattr(models, opt.model1)
|
130 |
+
model2 = getattr(models, opt.model2)
|
131 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
132 |
+
|
133 |
+
model = ModelNSNet(model1, model2)
|
134 |
+
print(model)
|
135 |
+
|
136 |
+
if opt.num_gpus > 0:
|
137 |
+
model = model.cuda()
|
138 |
+
if opt.num_gpus > 1:
|
139 |
+
model = nn.DataParallel(model)
|
140 |
+
|
141 |
+
if opt.optim == 'adam':
|
142 |
+
optimizer = torch.optim.Adam(model.model_nsnet.parameters(), opt.lr)
|
143 |
+
else:
|
144 |
+
optimizer = torch.optim.SGD(model.model_nsnet.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
145 |
+
|
146 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
147 |
+
epoch = 0
|
148 |
+
step = 0
|
149 |
+
model.model_nsnet.train()
|
150 |
+
|
151 |
+
num_iter_per_epoch = len(training_generator)
|
152 |
+
|
153 |
+
try:
|
154 |
+
for epoch in range(opt.num_epochs):
|
155 |
+
last_epoch = step // num_iter_per_epoch
|
156 |
+
if epoch < last_epoch:
|
157 |
+
continue
|
158 |
+
|
159 |
+
epoch_loss = []
|
160 |
+
progress_bar = tqdm(training_generator)
|
161 |
+
|
162 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
163 |
+
if not opt.sampling:
|
164 |
+
for iter, (data, target, name) in enumerate(progress_bar):
|
165 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
166 |
+
progress_bar.update()
|
167 |
+
continue
|
168 |
+
try:
|
169 |
+
if opt.num_gpus == 1:
|
170 |
+
data = data.cuda()
|
171 |
+
target = target.cuda()
|
172 |
+
|
173 |
+
optimizer.zero_grad()
|
174 |
+
|
175 |
+
noisy_gt, texture_ns, texture_res, illumi, \
|
176 |
+
restor_loss, psnr, ssim = model(data, target, training=True)
|
177 |
+
|
178 |
+
loss = restor_loss
|
179 |
+
|
180 |
+
loss.backward()
|
181 |
+
optimizer.step()
|
182 |
+
|
183 |
+
epoch_loss.append(float(loss))
|
184 |
+
|
185 |
+
progress_bar.set_description(
|
186 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. restor_loss: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
187 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, restor_loss.item(), psnr,
|
188 |
+
ssim))
|
189 |
+
writer.add_scalar('Loss/train', loss, step)
|
190 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
191 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
192 |
+
|
193 |
+
# log learning_rate
|
194 |
+
current_lr = optimizer.param_groups[0]['lr']
|
195 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
196 |
+
|
197 |
+
step += 1
|
198 |
+
|
199 |
+
except Exception as e:
|
200 |
+
print('[Error]', traceback.format_exc())
|
201 |
+
print(e)
|
202 |
+
continue
|
203 |
+
|
204 |
+
if not opt.no_sche:
|
205 |
+
scheduler.step()
|
206 |
+
|
207 |
+
if epoch % opt.val_interval == 0:
|
208 |
+
model.model_nsnet.eval()
|
209 |
+
loss_ls = []
|
210 |
+
psnrs = []
|
211 |
+
ssims = []
|
212 |
+
|
213 |
+
for iter, (data, target, name) in enumerate(val_generator):
|
214 |
+
with torch.no_grad():
|
215 |
+
if opt.num_gpus == 1:
|
216 |
+
data = data.cuda()
|
217 |
+
target = target.cuda()
|
218 |
+
|
219 |
+
noisy_gt, texture_ns, texture_res, \
|
220 |
+
illumi, restor_loss, psnr, ssim = model(data, target, training=False)
|
221 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
|
222 |
+
|
223 |
+
saver.save_image(noisy_gt, name=os.path.splitext(name[0])[0] + '_in')
|
224 |
+
saver.save_image(texture_ns, name=os.path.splitext(name[0])[0] + '_ns')
|
225 |
+
saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
|
226 |
+
saver.save_image(illumi, name=os.path.splitext(name[0])[0] + '_ill')
|
227 |
+
saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
|
228 |
+
|
229 |
+
loss = restor_loss
|
230 |
+
loss_ls.append(loss.item())
|
231 |
+
psnrs.append(psnr)
|
232 |
+
ssims.append(ssim)
|
233 |
+
|
234 |
+
loss = np.mean(np.array(loss_ls))
|
235 |
+
psnr = np.mean(np.array(psnrs))
|
236 |
+
ssim = np.mean(np.array(ssims))
|
237 |
+
|
238 |
+
print(
|
239 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
240 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
241 |
+
writer.add_scalar('Loss/val', loss, step)
|
242 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
243 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
244 |
+
|
245 |
+
save_checkpoint(model, f'{opt.model2}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
246 |
+
|
247 |
+
model.model_nsnet.train()
|
248 |
+
|
249 |
+
except KeyboardInterrupt:
|
250 |
+
save_checkpoint(model, f'{opt.model2}_{epoch}_{step}_keyboardInterrupt.pth')
|
251 |
+
writer.close()
|
252 |
+
writer.close()
|
253 |
+
|
254 |
+
|
255 |
+
def save_checkpoint(model, name):
|
256 |
+
if isinstance(model, nn.DataParallel):
|
257 |
+
torch.save(model.module.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
|
258 |
+
else:
|
259 |
+
torch.save(model.model_nsnet.state_dict(), os.path.join(opt.saved_path, name))
|
260 |
+
|
261 |
+
|
262 |
+
if __name__ == '__main__':
|
263 |
+
opt = get_args()
|
264 |
+
train(opt)
|
train_CAN.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import kornia
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from tqdm.autonotebook import tqdm
|
13 |
+
|
14 |
+
import models
|
15 |
+
from datasets import LowLightFDataset, LowLightFDatasetEval
|
16 |
+
from models import PSNR, SSIM, CosineLR
|
17 |
+
from tools import SingleSummaryWriter
|
18 |
+
from tools import saver, mutils
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
23 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
24 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
25 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
26 |
+
parser.add_argument('-m1', '--model1', type=str, default='INet',
|
27 |
+
help='Model Name')
|
28 |
+
parser.add_argument('-m3', '--model3', type=str, default='INet',
|
29 |
+
help='Model Name')
|
30 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
|
31 |
+
help='Model Name')
|
32 |
+
parser.add_argument('-m3w', '--model3_weight', type=str, default=None,
|
33 |
+
help='Model Name')
|
34 |
+
parser.add_argument('-ts', '--targets_split', type=str, default='targets',
|
35 |
+
help='dir of targets')
|
36 |
+
parser.add_argument('--comment', type=str, default='default',
|
37 |
+
help='Project comment')
|
38 |
+
parser.add_argument('--graph', action='store_true')
|
39 |
+
parser.add_argument('--scratch', action='store_true')
|
40 |
+
parser.add_argument('--sampling', action='store_true')
|
41 |
+
parser.add_argument('--test_on_start', action='store_true')
|
42 |
+
|
43 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
44 |
+
parser.add_argument('--no_sche', action='store_true')
|
45 |
+
|
46 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
47 |
+
'suggest using \'admaw\' until the'
|
48 |
+
' very final stage then switch to \'sgd\'')
|
49 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
50 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
51 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
52 |
+
parser.add_argument('--data_path', type=str, default='./data/LOL',
|
53 |
+
help='the root folder of dataset')
|
54 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
55 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
56 |
+
args = parser.parse_args()
|
57 |
+
return args
|
58 |
+
|
59 |
+
|
60 |
+
def compute_gradient(img):
|
61 |
+
gradx = img[..., 1:, :] - img[..., :-1, :]
|
62 |
+
grady = img[..., 1:] - img[..., :-1]
|
63 |
+
return gradx, grady
|
64 |
+
|
65 |
+
|
66 |
+
class ModelCANet(nn.Module):
|
67 |
+
def __init__(self, model1, model3):
|
68 |
+
super().__init__()
|
69 |
+
self.color_loss = models.L1Loss()
|
70 |
+
self.restor_loss = models.MSSSIML1Loss(channels=3)
|
71 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
72 |
+
self.model_canet = model3(in_channels=6, out_channels=2)
|
73 |
+
self.eps = 1e-2
|
74 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
75 |
+
if opt.model3_weight is not None:
|
76 |
+
self.load_weight(self.model_canet, opt.model3_weight)
|
77 |
+
self.model_ianet.eval()
|
78 |
+
|
79 |
+
def load_weight(self, model, weight_pth):
|
80 |
+
state_dict = torch.load(weight_pth)
|
81 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
82 |
+
print(ret)
|
83 |
+
|
84 |
+
def forward(self, image, image_gt, training=True):
|
85 |
+
if training:
|
86 |
+
image = image.squeeze(0)
|
87 |
+
image_gt = image_gt.repeat(8, 1, 1, 1)
|
88 |
+
|
89 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
90 |
+
|
91 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
92 |
+
texture_illumi = self.model_ianet(texture_in_down)
|
93 |
+
texture_illumi = F.interpolate(texture_illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
94 |
+
|
95 |
+
texture_en, cb_en, cr_en = torch.split(kornia.color.rgb_to_ycbcr(image / torch.clamp_min(texture_illumi, self.eps)),
|
96 |
+
1, dim=1)
|
97 |
+
texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
98 |
+
|
99 |
+
colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt, cb_en, cr_en], dim=1))
|
100 |
+
|
101 |
+
cb, cr = torch.split(colors, 1, dim=1)
|
102 |
+
|
103 |
+
color_loss1 = self.color_loss(cb, cb_gt)
|
104 |
+
color_loss2 = self.color_loss(cr, cr_gt)
|
105 |
+
|
106 |
+
image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
|
107 |
+
restor_loss = self.restor_loss(image_out, image_gt) * 1.0
|
108 |
+
|
109 |
+
psnr = PSNR(image_out, image_gt)
|
110 |
+
ssim = SSIM(image_out, image_gt).item()
|
111 |
+
return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
|
112 |
+
|
113 |
+
|
114 |
+
def train(opt):
|
115 |
+
if torch.cuda.is_available():
|
116 |
+
torch.cuda.manual_seed(42)
|
117 |
+
else:
|
118 |
+
torch.manual_seed(42)
|
119 |
+
|
120 |
+
timestamp = mutils.get_formatted_time()
|
121 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
122 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
123 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
124 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
125 |
+
|
126 |
+
training_params = {'batch_size': opt.batch_size,
|
127 |
+
'shuffle': True,
|
128 |
+
'drop_last': True,
|
129 |
+
'num_workers': opt.num_workers}
|
130 |
+
|
131 |
+
val_params = {'batch_size': 1,
|
132 |
+
'shuffle': False,
|
133 |
+
'drop_last': False,
|
134 |
+
'num_workers': opt.num_workers}
|
135 |
+
|
136 |
+
training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), targets_split=opt.targets_split,
|
137 |
+
training=True)
|
138 |
+
training_generator = DataLoader(training_set, **training_params)
|
139 |
+
|
140 |
+
val_set = LowLightFDatasetEval(os.path.join(opt.data_path, 'eval'), training=False)
|
141 |
+
val_generator = DataLoader(val_set, **val_params)
|
142 |
+
|
143 |
+
model1 = getattr(models, opt.model1)
|
144 |
+
model3 = getattr(models, opt.model3)
|
145 |
+
|
146 |
+
model = ModelCANet(model1, model3)
|
147 |
+
print(model)
|
148 |
+
|
149 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
150 |
+
|
151 |
+
if opt.num_gpus > 0:
|
152 |
+
model = model.cuda()
|
153 |
+
if opt.num_gpus > 1:
|
154 |
+
model = nn.DataParallel(model)
|
155 |
+
|
156 |
+
if opt.optim == 'adam':
|
157 |
+
optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
|
158 |
+
else:
|
159 |
+
optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
160 |
+
|
161 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
162 |
+
epoch = 0
|
163 |
+
step = 0
|
164 |
+
model.model_canet.train()
|
165 |
+
|
166 |
+
num_iter_per_epoch = len(training_generator)
|
167 |
+
|
168 |
+
try:
|
169 |
+
for epoch in range(opt.num_epochs):
|
170 |
+
last_epoch = step // num_iter_per_epoch
|
171 |
+
if epoch < last_epoch:
|
172 |
+
continue
|
173 |
+
|
174 |
+
epoch_loss = []
|
175 |
+
progress_bar = tqdm(training_generator)
|
176 |
+
if not opt.sampling and not opt.test_on_start:
|
177 |
+
for iter, (data, target, name) in enumerate(progress_bar):
|
178 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
179 |
+
progress_bar.update()
|
180 |
+
continue
|
181 |
+
try:
|
182 |
+
if opt.num_gpus == 1:
|
183 |
+
data, target = data.cuda(), target.cuda()
|
184 |
+
optimizer.zero_grad()
|
185 |
+
|
186 |
+
image_out, color_loss1, color_loss2, \
|
187 |
+
restor_loss, psnr, ssim = model(data, target, training=True)
|
188 |
+
loss = color_loss1 + color_loss2 + restor_loss
|
189 |
+
loss.backward()
|
190 |
+
optimizer.step()
|
191 |
+
|
192 |
+
epoch_loss.append(float(loss))
|
193 |
+
|
194 |
+
progress_bar.set_description(
|
195 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
196 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
|
197 |
+
color_loss1.item(), color_loss2.item(),
|
198 |
+
restor_loss.item(), psnr, ssim))
|
199 |
+
writer.add_scalar('Loss/train', loss, step)
|
200 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
201 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
202 |
+
|
203 |
+
# log learning_rate
|
204 |
+
current_lr = optimizer.param_groups[0]['lr']
|
205 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
206 |
+
|
207 |
+
step += 1
|
208 |
+
|
209 |
+
except Exception as e:
|
210 |
+
print('[Error]', traceback.format_exc())
|
211 |
+
print(e)
|
212 |
+
continue
|
213 |
+
# scheduler.step(np.mean(epoch_loss))
|
214 |
+
|
215 |
+
if opt.no_sche:
|
216 |
+
scheduler.step()
|
217 |
+
|
218 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
219 |
+
|
220 |
+
if epoch % opt.val_interval == 0:
|
221 |
+
model.model_canet.eval()
|
222 |
+
loss_ls = []
|
223 |
+
psnrs = []
|
224 |
+
ssims = []
|
225 |
+
|
226 |
+
for iter, (data, target, name) in enumerate(val_generator):
|
227 |
+
with torch.no_grad():
|
228 |
+
if opt.num_gpus == 1:
|
229 |
+
data = data.squeeze(0).cuda()
|
230 |
+
target = target.cuda()
|
231 |
+
|
232 |
+
image_out, color_loss1, color_loss2, restor_loss, \
|
233 |
+
psnr, ssim = model(data, target, training=False)
|
234 |
+
saver.save_image(image_out, name=os.path.splitext(name[0])[0] + '_out')
|
235 |
+
saver.save_image(data, name=os.path.splitext(name[0])[0] + '_in')
|
236 |
+
saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
|
237 |
+
|
238 |
+
loss = restor_loss + color_loss1 + color_loss2
|
239 |
+
loss_ls.append(loss.item())
|
240 |
+
psnrs.append(psnr)
|
241 |
+
ssims.append(ssim)
|
242 |
+
|
243 |
+
loss = np.mean(np.array(loss_ls))
|
244 |
+
psnr = np.mean(np.array(psnrs))
|
245 |
+
ssim = np.mean(np.array(ssims))
|
246 |
+
|
247 |
+
print(
|
248 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
249 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
250 |
+
writer.add_scalar('Loss/val', loss, step)
|
251 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
252 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
253 |
+
|
254 |
+
save_checkpoint(model, f'{opt.model3}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
255 |
+
|
256 |
+
model.model_canet.train()
|
257 |
+
|
258 |
+
opt.test_on_start = False
|
259 |
+
if opt.sampling:
|
260 |
+
exit(0)
|
261 |
+
except KeyboardInterrupt:
|
262 |
+
save_checkpoint(model, f'{opt.model3}_{epoch}_{step}_keyboardInterrupt.pth')
|
263 |
+
writer.close()
|
264 |
+
writer.close()
|
265 |
+
|
266 |
+
|
267 |
+
def save_checkpoint(model, name):
|
268 |
+
if isinstance(model, nn.DataParallel):
|
269 |
+
torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
270 |
+
else:
|
271 |
+
torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
272 |
+
|
273 |
+
|
274 |
+
if __name__ == '__main__':
|
275 |
+
opt = get_args()
|
276 |
+
train(opt)
|
train_IAN.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import kornia
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from tqdm.autonotebook import tqdm
|
13 |
+
|
14 |
+
import models
|
15 |
+
from datasets import LowLightDataset, LowLightFDataset
|
16 |
+
from models import PSNR, SSIM, CosineLR
|
17 |
+
from tools import SingleSummaryWriter
|
18 |
+
from tools import saver, mutils
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
23 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
24 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
25 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
26 |
+
parser.add_argument('-m', '--model', type=str, default='INet',
|
27 |
+
help='Model Name')
|
28 |
+
parser.add_argument('--comment', type=str, default='default',
|
29 |
+
help='Project comment')
|
30 |
+
parser.add_argument('--graph', action='store_true')
|
31 |
+
parser.add_argument('--scratch', action='store_true')
|
32 |
+
|
33 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
34 |
+
parser.add_argument('--no_sche', action='store_true')
|
35 |
+
|
36 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
37 |
+
'suggest using \'admaw\' until the'
|
38 |
+
' very final stage then switch to \'sgd\'')
|
39 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
40 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
41 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
42 |
+
parser.add_argument('--data_path', type=str, default='./data/LOL',
|
43 |
+
help='the root folder of dataset')
|
44 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
45 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
46 |
+
args = parser.parse_args()
|
47 |
+
return args
|
48 |
+
|
49 |
+
|
50 |
+
def compute_gradient(img):
|
51 |
+
gradx = img[..., 1:, :] - img[..., :-1, :]
|
52 |
+
grady = img[..., 1:] - img[..., :-1]
|
53 |
+
return gradx, grady
|
54 |
+
|
55 |
+
|
56 |
+
class ModelINet(nn.Module):
|
57 |
+
def __init__(self, model):
|
58 |
+
super().__init__()
|
59 |
+
self.restor_loss = models.MSELoss()
|
60 |
+
self.wtv_loss = models.WTVLoss2()
|
61 |
+
self.model = model(in_channels=1, out_channels=1)
|
62 |
+
self.eps = 1e-2
|
63 |
+
|
64 |
+
def forward(self, image, image_gt, training=True):
|
65 |
+
if training:
|
66 |
+
image = image.squeeze(0)
|
67 |
+
image_gt = image_gt.repeat(8, 1, 1, 1)
|
68 |
+
|
69 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
70 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
71 |
+
|
72 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
73 |
+
texture_gt_down = F.interpolate(texture_gt, scale_factor=0.5, mode='bicubic', align_corners=True)
|
74 |
+
|
75 |
+
illumi = self.model(texture_in_down)
|
76 |
+
|
77 |
+
texture_out = texture_in_down / torch.clamp_min(illumi, self.eps)
|
78 |
+
restor_loss = self.restor_loss(texture_out, texture_gt_down)
|
79 |
+
restor_loss += self.restor_loss(texture_in_down, texture_gt_down * illumi)
|
80 |
+
|
81 |
+
tv_loss = self.wtv_loss(illumi, texture_in_down)
|
82 |
+
if training:
|
83 |
+
psnr = 0.0
|
84 |
+
ssim = 0.0
|
85 |
+
else:
|
86 |
+
illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
87 |
+
texture_out = texture_in / torch.clamp_min(illumi, self.eps)
|
88 |
+
|
89 |
+
psnr = PSNR(texture_out, texture_gt)
|
90 |
+
ssim = SSIM(texture_out, texture_gt).item()
|
91 |
+
return texture_out, illumi, restor_loss, tv_loss, psnr, ssim
|
92 |
+
|
93 |
+
|
94 |
+
def train(opt):
|
95 |
+
if torch.cuda.is_available():
|
96 |
+
torch.cuda.manual_seed(42)
|
97 |
+
else:
|
98 |
+
torch.manual_seed(42)
|
99 |
+
|
100 |
+
timestamp = mutils.get_formatted_time()
|
101 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
102 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
103 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
104 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
105 |
+
|
106 |
+
training_params = {'batch_size': opt.batch_size,
|
107 |
+
'shuffle': True,
|
108 |
+
'drop_last': True,
|
109 |
+
'num_workers': opt.num_workers}
|
110 |
+
|
111 |
+
val_params = {'batch_size': 1,
|
112 |
+
'shuffle': False,
|
113 |
+
'drop_last': True,
|
114 |
+
'num_workers': opt.num_workers}
|
115 |
+
|
116 |
+
training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), image_split='images_aug',
|
117 |
+
targets_split='targets')
|
118 |
+
training_generator = DataLoader(training_set, **training_params)
|
119 |
+
|
120 |
+
val_set = LowLightDataset(os.path.join(opt.data_path, 'eval'), targets_split='targets')
|
121 |
+
val_generator = DataLoader(val_set, **val_params)
|
122 |
+
|
123 |
+
model = getattr(models, opt.model)
|
124 |
+
|
125 |
+
model = ModelINet(model)
|
126 |
+
print(model)
|
127 |
+
# load last weights
|
128 |
+
|
129 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
130 |
+
|
131 |
+
if opt.num_gpus > 0:
|
132 |
+
model = model.cuda()
|
133 |
+
if opt.num_gpus > 1:
|
134 |
+
model = nn.DataParallel(model)
|
135 |
+
|
136 |
+
if opt.optim == 'adam':
|
137 |
+
optimizer = torch.optim.Adam(model.parameters(), opt.lr)
|
138 |
+
else:
|
139 |
+
optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
140 |
+
|
141 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
142 |
+
epoch = 0
|
143 |
+
step = 0
|
144 |
+
model.train()
|
145 |
+
|
146 |
+
num_iter_per_epoch = len(training_generator)
|
147 |
+
|
148 |
+
try:
|
149 |
+
for epoch in range(opt.num_epochs):
|
150 |
+
last_epoch = step // num_iter_per_epoch
|
151 |
+
if epoch < last_epoch:
|
152 |
+
continue
|
153 |
+
|
154 |
+
epoch_loss = []
|
155 |
+
progress_bar = tqdm(training_generator)
|
156 |
+
for iter, (data, target, name) in enumerate(progress_bar):
|
157 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
158 |
+
progress_bar.update()
|
159 |
+
continue
|
160 |
+
try:
|
161 |
+
if opt.num_gpus == 1:
|
162 |
+
data, target = data.cuda(), target.cuda()
|
163 |
+
|
164 |
+
optimizer.zero_grad()
|
165 |
+
|
166 |
+
texture_out, texture_attention, restor_loss, \
|
167 |
+
tv_loss, psnr, ssim = model(data, target, training=True)
|
168 |
+
loss = restor_loss + tv_loss
|
169 |
+
loss.backward()
|
170 |
+
optimizer.step()
|
171 |
+
|
172 |
+
epoch_loss.append(float(loss))
|
173 |
+
|
174 |
+
progress_bar.set_description(
|
175 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. var: {:.5f}, res_loss: {:.5f}, tv_loss: {:.5f}, psnr: {:.3f}, ssim: {:.3f}'.format(
|
176 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, torch.var(texture_attention),
|
177 |
+
restor_loss.item(),
|
178 |
+
tv_loss.item(), psnr, ssim))
|
179 |
+
writer.add_scalar('Loss/train', loss, step)
|
180 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
181 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
182 |
+
|
183 |
+
# log learning_rate
|
184 |
+
current_lr = optimizer.param_groups[0]['lr']
|
185 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
186 |
+
|
187 |
+
step += 1
|
188 |
+
|
189 |
+
except Exception as e:
|
190 |
+
print('[Error]', traceback.format_exc())
|
191 |
+
print(e)
|
192 |
+
continue
|
193 |
+
|
194 |
+
if opt.no_sche:
|
195 |
+
scheduler.step()
|
196 |
+
|
197 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
198 |
+
|
199 |
+
if epoch % opt.val_interval == 0:
|
200 |
+
model.eval()
|
201 |
+
loss_ls = []
|
202 |
+
psnrs = []
|
203 |
+
ssims = []
|
204 |
+
|
205 |
+
for iter, (data, target, name) in enumerate(val_generator):
|
206 |
+
with torch.no_grad():
|
207 |
+
if opt.num_gpus == 1:
|
208 |
+
data = data.cuda()
|
209 |
+
target = target.cuda()
|
210 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(data), 1, dim=1)
|
211 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
|
212 |
+
|
213 |
+
texture_out, texture_attention, restor_loss, \
|
214 |
+
tv_loss, psnr, ssim = model(data, target, training=False)
|
215 |
+
saver.save_image(texture_out, name=os.path.splitext(name[0])[0] + '_out')
|
216 |
+
saver.save_image(texture_in, name=os.path.splitext(name[0])[0] + '_in')
|
217 |
+
saver.save_image(texture_gt, name=os.path.splitext(name[0])[0] + '_gt')
|
218 |
+
saver.save_image(texture_attention, name=os.path.splitext(name[0])[0] + '_att')
|
219 |
+
|
220 |
+
loss = restor_loss + tv_loss
|
221 |
+
loss_ls.append(loss.item())
|
222 |
+
psnrs.append(psnr)
|
223 |
+
ssims.append(ssim)
|
224 |
+
|
225 |
+
loss = np.mean(np.array(loss_ls))
|
226 |
+
psnr = np.mean(np.array(psnrs))
|
227 |
+
ssim = np.mean(np.array(ssims))
|
228 |
+
|
229 |
+
print(
|
230 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
231 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
232 |
+
writer.add_scalar('Loss/val', loss, step)
|
233 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
234 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
235 |
+
|
236 |
+
save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
237 |
+
|
238 |
+
model.train()
|
239 |
+
|
240 |
+
except KeyboardInterrupt:
|
241 |
+
save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
|
242 |
+
writer.close()
|
243 |
+
writer.close()
|
244 |
+
|
245 |
+
|
246 |
+
def save_checkpoint(model, name):
|
247 |
+
if isinstance(model, nn.DataParallel):
|
248 |
+
torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
|
249 |
+
else:
|
250 |
+
torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == '__main__':
|
254 |
+
opt = get_args()
|
255 |
+
train(opt)
|
train_MECAN.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import kornia
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm.autonotebook import tqdm
|
12 |
+
|
13 |
+
import models
|
14 |
+
from datasets import MEFDataset
|
15 |
+
from models import PSNR, SSIM, CosineLR
|
16 |
+
from tools import SingleSummaryWriter
|
17 |
+
from tools import saver, mutils
|
18 |
+
|
19 |
+
|
20 |
+
def get_args():
|
21 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
22 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
23 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
24 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
25 |
+
parser.add_argument('-m', '--model', type=str, default='INet',
|
26 |
+
help='Model Name')
|
27 |
+
parser.add_argument('-ts', '--targets_split', type=str, default='targets',
|
28 |
+
help='dir of targets')
|
29 |
+
parser.add_argument('--comment', type=str, default='default',
|
30 |
+
help='Project comment')
|
31 |
+
parser.add_argument('--graph', action='store_true')
|
32 |
+
parser.add_argument('--scratch', action='store_true')
|
33 |
+
parser.add_argument('--sampling', action='store_true')
|
34 |
+
parser.add_argument('--test_on_start', action='store_true')
|
35 |
+
|
36 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
37 |
+
parser.add_argument('--no_sche', action='store_true')
|
38 |
+
|
39 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
40 |
+
'suggest using \'admaw\' until the'
|
41 |
+
' very final stage then switch to \'sgd\'')
|
42 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
43 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
44 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
45 |
+
parser.add_argument('--data_path', type=str, default='./data/SICE',
|
46 |
+
help='the root folder of dataset')
|
47 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
48 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
49 |
+
parser.add_argument('-r', '--resize', type=int, default=-1, help='resize of training images')
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def compute_gradient(img):
|
55 |
+
gradx = img[..., 1:, :] - img[..., :-1, :]
|
56 |
+
grady = img[..., 1:] - img[..., :-1]
|
57 |
+
return gradx, grady
|
58 |
+
|
59 |
+
|
60 |
+
class ModelINet(nn.Module):
|
61 |
+
def __init__(self, model):
|
62 |
+
super().__init__()
|
63 |
+
self.color_loss = models.L1Loss()
|
64 |
+
self.restor_loss = models.MSSSIML1Loss(channels=3)
|
65 |
+
self.model_canet = model(in_channels=4, out_channels=2)
|
66 |
+
self.eps = 1e-2
|
67 |
+
|
68 |
+
def load_weight(self, model, weight_pth):
|
69 |
+
state_dict = torch.load(weight_pth)
|
70 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
71 |
+
print(ret)
|
72 |
+
|
73 |
+
def forward(self, image, image_gt, training=True):
|
74 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
75 |
+
texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
76 |
+
|
77 |
+
colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt], dim=1))
|
78 |
+
cb, cr = torch.split(colors, 1, dim=1)
|
79 |
+
|
80 |
+
color_loss1 = self.color_loss(cb, cb_gt)
|
81 |
+
color_loss2 = self.color_loss(cr, cr_gt)
|
82 |
+
|
83 |
+
image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
|
84 |
+
restor_loss = self.restor_loss(image_out, image_gt)
|
85 |
+
|
86 |
+
psnr = PSNR(image_out, image_gt)
|
87 |
+
ssim = SSIM(image_out, image_gt).item()
|
88 |
+
return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
|
89 |
+
|
90 |
+
|
91 |
+
def train(opt):
|
92 |
+
if torch.cuda.is_available():
|
93 |
+
torch.cuda.manual_seed(42)
|
94 |
+
else:
|
95 |
+
torch.manual_seed(42)
|
96 |
+
|
97 |
+
# params.project_name = params.project_name + str(time.time()).replace('.', '')
|
98 |
+
timestamp = mutils.get_formatted_time()
|
99 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
100 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
101 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
102 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
103 |
+
|
104 |
+
training_params = {'batch_size': opt.batch_size,
|
105 |
+
'shuffle': True,
|
106 |
+
'drop_last': True,
|
107 |
+
'num_workers': opt.num_workers}
|
108 |
+
|
109 |
+
val_params = {'batch_size': 1,
|
110 |
+
'shuffle': False,
|
111 |
+
'drop_last': True,
|
112 |
+
'num_workers': opt.num_workers}
|
113 |
+
|
114 |
+
training_set = MEFDataset(os.path.join(opt.data_path, 'train'))
|
115 |
+
training_generator = DataLoader(training_set, **training_params)
|
116 |
+
|
117 |
+
val_set = MEFDataset(os.path.join(opt.data_path, 'eval'))
|
118 |
+
val_generator = DataLoader(val_set, **val_params)
|
119 |
+
|
120 |
+
model = getattr(models, opt.model)
|
121 |
+
|
122 |
+
model = ModelINet(model)
|
123 |
+
print(model)
|
124 |
+
|
125 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
126 |
+
|
127 |
+
if opt.num_gpus > 0:
|
128 |
+
model = model.cuda()
|
129 |
+
if opt.num_gpus > 1:
|
130 |
+
model = nn.DataParallel(model)
|
131 |
+
|
132 |
+
if opt.optim == 'adam':
|
133 |
+
optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
|
134 |
+
else:
|
135 |
+
optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
136 |
+
|
137 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
138 |
+
epoch = 0
|
139 |
+
step = 0
|
140 |
+
model.model_canet.train()
|
141 |
+
|
142 |
+
num_iter_per_epoch = len(training_generator)
|
143 |
+
|
144 |
+
try:
|
145 |
+
for epoch in range(opt.num_epochs):
|
146 |
+
last_epoch = step // num_iter_per_epoch
|
147 |
+
if epoch < last_epoch:
|
148 |
+
continue
|
149 |
+
|
150 |
+
epoch_loss = []
|
151 |
+
progress_bar = tqdm(training_generator)
|
152 |
+
if not opt.sampling and not opt.test_on_start:
|
153 |
+
for iter, (data, target, name1, name2) in enumerate(progress_bar):
|
154 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
155 |
+
progress_bar.update()
|
156 |
+
continue
|
157 |
+
try:
|
158 |
+
if opt.num_gpus == 1:
|
159 |
+
data, target = data.cuda(), target.cuda()
|
160 |
+
optimizer.zero_grad()
|
161 |
+
|
162 |
+
image_out, color_loss1, color_loss2, \
|
163 |
+
restor_loss, psnr, ssim = model(data, target, training=True)
|
164 |
+
loss = color_loss1 + color_loss2 + restor_loss
|
165 |
+
|
166 |
+
loss.backward()
|
167 |
+
optimizer.step()
|
168 |
+
|
169 |
+
epoch_loss.append(float(loss))
|
170 |
+
|
171 |
+
progress_bar.set_description(
|
172 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
173 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
|
174 |
+
color_loss1.item(), color_loss2.item(),
|
175 |
+
restor_loss.item(), psnr, ssim))
|
176 |
+
writer.add_scalar('Loss/train', loss, step)
|
177 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
178 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
179 |
+
|
180 |
+
# log learning_rate
|
181 |
+
current_lr = optimizer.param_groups[0]['lr']
|
182 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
183 |
+
|
184 |
+
step += 1
|
185 |
+
|
186 |
+
except Exception as e:
|
187 |
+
print('[Error]', traceback.format_exc())
|
188 |
+
print(e)
|
189 |
+
continue
|
190 |
+
|
191 |
+
if opt.no_sche:
|
192 |
+
scheduler.step()
|
193 |
+
|
194 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
195 |
+
|
196 |
+
if epoch % opt.val_interval == 0:
|
197 |
+
model.model_canet.eval()
|
198 |
+
loss_ls = []
|
199 |
+
psnrs = []
|
200 |
+
ssims = []
|
201 |
+
|
202 |
+
for iter, (data, target, name1, name2) in enumerate(val_generator):
|
203 |
+
with torch.no_grad():
|
204 |
+
if opt.num_gpus == 1:
|
205 |
+
data, target = data.cuda(), target.cuda()
|
206 |
+
|
207 |
+
image_out, color_loss1, color_loss2, restor_loss, \
|
208 |
+
psnr, ssim = model(data, target, training=False)
|
209 |
+
saver.save_image(data, name=os.path.splitext(name1[0])[0] + '_im1')
|
210 |
+
saver.save_image(target, name=os.path.splitext(name2[0])[0] + '_im2')
|
211 |
+
saver.save_image(image_out, name=os.path.splitext(name2[0])[0] + '_im2_pred')
|
212 |
+
|
213 |
+
loss = restor_loss + color_loss1 + color_loss2
|
214 |
+
loss_ls.append(loss.item())
|
215 |
+
psnrs.append(psnr)
|
216 |
+
ssims.append(ssim)
|
217 |
+
|
218 |
+
# reverse
|
219 |
+
image_out, color_loss1, color_loss2, restor_loss, \
|
220 |
+
psnr, ssim = model(target, data, training=False)
|
221 |
+
saver.save_image(image_out, name=os.path.splitext(name1[0])[0] + '_im1_pred')
|
222 |
+
|
223 |
+
loss = restor_loss + color_loss1 + color_loss2
|
224 |
+
loss_ls.append(loss.item())
|
225 |
+
psnrs.append(psnr)
|
226 |
+
ssims.append(ssim)
|
227 |
+
|
228 |
+
loss = np.mean(np.array(loss_ls))
|
229 |
+
psnr = np.mean(np.array(psnrs))
|
230 |
+
ssim = np.mean(np.array(ssims))
|
231 |
+
|
232 |
+
print(
|
233 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
234 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
235 |
+
writer.add_scalar('Loss/val', loss, step)
|
236 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
237 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
238 |
+
|
239 |
+
save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
240 |
+
|
241 |
+
model.model_canet.train()
|
242 |
+
|
243 |
+
opt.test_on_start = False
|
244 |
+
if opt.sampling:
|
245 |
+
exit(0)
|
246 |
+
except KeyboardInterrupt:
|
247 |
+
save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
|
248 |
+
writer.close()
|
249 |
+
writer.close()
|
250 |
+
|
251 |
+
|
252 |
+
def save_checkpoint(model, name):
|
253 |
+
if isinstance(model, nn.DataParallel):
|
254 |
+
torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
255 |
+
else:
|
256 |
+
torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
257 |
+
|
258 |
+
|
259 |
+
if __name__ == '__main__':
|
260 |
+
opt = get_args()
|
261 |
+
train(opt)
|
train_MECAN_finetune.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# original author: signatrix
|
2 |
+
# adapted from https://github.com/signatrix/efficientdet/blob/master/train.py
|
3 |
+
# modified by Zylo117
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import datetime
|
7 |
+
import os
|
8 |
+
import traceback
|
9 |
+
|
10 |
+
import kornia
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from tqdm.autonotebook import tqdm
|
16 |
+
|
17 |
+
import models
|
18 |
+
from datasets import MEFDataset, LowLightDataset, LowLightDatasetReverse
|
19 |
+
from models import PSNR, SSIM, CosineLR
|
20 |
+
from tools import SingleSummaryWriter
|
21 |
+
from tools import saver, mutils
|
22 |
+
|
23 |
+
|
24 |
+
def get_args():
|
25 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
26 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
27 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
28 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
29 |
+
parser.add_argument('-m', '--model', type=str, default='INet',
|
30 |
+
help='Model Name')
|
31 |
+
parser.add_argument('-mw', '--model_weight', type=str, default=None,
|
32 |
+
help='Model weight')
|
33 |
+
parser.add_argument('-ts', '--targets_split', type=str, default='targets',
|
34 |
+
help='dir of targets')
|
35 |
+
parser.add_argument('--comment', type=str, default='default',
|
36 |
+
help='Project comment')
|
37 |
+
parser.add_argument('--graph', action='store_true')
|
38 |
+
parser.add_argument('--scratch', action='store_true')
|
39 |
+
parser.add_argument('--sampling', action='store_true')
|
40 |
+
parser.add_argument('--test_on_start', action='store_true')
|
41 |
+
|
42 |
+
parser.add_argument('--lr', type=float, default=0.01)
|
43 |
+
parser.add_argument('--no_sche', action='store_true')
|
44 |
+
|
45 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
46 |
+
'suggest using \'admaw\' until the'
|
47 |
+
' very final stage then switch to \'sgd\'')
|
48 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
49 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
50 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
51 |
+
parser.add_argument('--data_path1', type=str, default='./data/SICE',
|
52 |
+
help='the root folder of dataset')
|
53 |
+
parser.add_argument('--data_path2', type=str, default='./data/LOL',
|
54 |
+
help='the root folder of dataset')
|
55 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
56 |
+
|
57 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
58 |
+
args = parser.parse_args()
|
59 |
+
return args
|
60 |
+
|
61 |
+
|
62 |
+
def compute_gradient(img):
|
63 |
+
gradx = img[..., 1:, :] - img[..., :-1, :]
|
64 |
+
grady = img[..., 1:] - img[..., :-1]
|
65 |
+
return gradx, grady
|
66 |
+
|
67 |
+
|
68 |
+
class ModelINet(nn.Module):
|
69 |
+
def __init__(self, model):
|
70 |
+
super().__init__()
|
71 |
+
self.color_loss = models.SSIML1Loss(channels=1)
|
72 |
+
self.restor_loss = models.SSIML1Loss(channels=3)
|
73 |
+
self.model_canet = model(in_channels=4, out_channels=2)
|
74 |
+
self.eps = 1e-2
|
75 |
+
self.load_weight(self.model_canet, opt.model_weight)
|
76 |
+
|
77 |
+
def load_weight(self, model, weight_pth):
|
78 |
+
state_dict = torch.load(weight_pth)
|
79 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
80 |
+
print(ret)
|
81 |
+
|
82 |
+
def forward(self, image, image_gt, training=True):
|
83 |
+
texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
84 |
+
texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
85 |
+
|
86 |
+
colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt], dim=1))
|
87 |
+
cb, cr = torch.split(colors, 1, dim=1)
|
88 |
+
|
89 |
+
color_loss1 = self.color_loss(cb, cb_gt)
|
90 |
+
color_loss2 = self.color_loss(cr, cr_gt)
|
91 |
+
|
92 |
+
image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
|
93 |
+
restor_loss = self.restor_loss(image_out, image_gt)
|
94 |
+
|
95 |
+
psnr = PSNR(image_out, image_gt)
|
96 |
+
ssim = SSIM(image_out, image_gt).item()
|
97 |
+
return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
|
98 |
+
|
99 |
+
|
100 |
+
def train(opt):
|
101 |
+
if torch.cuda.is_available():
|
102 |
+
torch.cuda.manual_seed(42)
|
103 |
+
else:
|
104 |
+
torch.manual_seed(42)
|
105 |
+
|
106 |
+
# params.project_name = params.project_name + str(time.time()).replace('.', '')
|
107 |
+
timestamp = mutils.get_formatted_time()
|
108 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
109 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
110 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
111 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
112 |
+
|
113 |
+
training_params = {'batch_size': opt.batch_size,
|
114 |
+
'shuffle': True,
|
115 |
+
'drop_last': True,
|
116 |
+
'num_workers': opt.num_workers}
|
117 |
+
|
118 |
+
val_params = {'batch_size': 1,
|
119 |
+
'shuffle': False,
|
120 |
+
'drop_last': True,
|
121 |
+
'num_workers': opt.num_workers}
|
122 |
+
|
123 |
+
training_set1 = MEFDataset(os.path.join(opt.data_path1, 'train'))
|
124 |
+
training_set2 = LowLightDataset(os.path.join(opt.data_path2, 'train'), color_tuning=True,
|
125 |
+
targets_split=opt.targets_split)
|
126 |
+
training_set3 = LowLightDatasetReverse(os.path.join(opt.data_path2, 'train'), color_tuning=True,
|
127 |
+
targets_split=opt.targets_split)
|
128 |
+
training_set = torch.utils.data.ConcatDataset([training_set1, training_set2, training_set3])
|
129 |
+
training_generator = DataLoader(training_set, **training_params)
|
130 |
+
|
131 |
+
# val_set = MEFDataset(os.path.join(opt.data_path1, 'eval'))
|
132 |
+
val_set = LowLightDataset(os.path.join(opt.data_path2, 'eval'), color_tuning=True)
|
133 |
+
# val_set = torch.utils.data.ConcatDataset([val_set1, val_set2])
|
134 |
+
val_generator = DataLoader(val_set, **val_params)
|
135 |
+
|
136 |
+
model = getattr(models, opt.model)
|
137 |
+
|
138 |
+
model = ModelINet(model)
|
139 |
+
print(model)
|
140 |
+
|
141 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
142 |
+
|
143 |
+
|
144 |
+
if opt.num_gpus > 0:
|
145 |
+
model = model.cuda()
|
146 |
+
if opt.num_gpus > 1:
|
147 |
+
model = nn.DataParallel(model)
|
148 |
+
|
149 |
+
if opt.optim == 'adam':
|
150 |
+
optimizer = torch.optim.Adam(model.model_canet.parameters(), opt.lr)
|
151 |
+
else:
|
152 |
+
optimizer = torch.optim.SGD(model.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
153 |
+
|
154 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
155 |
+
epoch = 0
|
156 |
+
step = 0
|
157 |
+
model.model_canet.train()
|
158 |
+
|
159 |
+
num_iter_per_epoch = len(training_generator)
|
160 |
+
|
161 |
+
try:
|
162 |
+
for epoch in range(opt.num_epochs):
|
163 |
+
last_epoch = step // num_iter_per_epoch
|
164 |
+
if epoch < last_epoch:
|
165 |
+
continue
|
166 |
+
|
167 |
+
epoch_loss = []
|
168 |
+
progress_bar = tqdm(training_generator)
|
169 |
+
if not opt.sampling and not opt.test_on_start:
|
170 |
+
for iter, (data, target, name1, name2) in enumerate(progress_bar):
|
171 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
172 |
+
progress_bar.update()
|
173 |
+
continue
|
174 |
+
try:
|
175 |
+
if opt.num_gpus == 1:
|
176 |
+
data, target = data.cuda(), target.cuda()
|
177 |
+
optimizer.zero_grad()
|
178 |
+
|
179 |
+
image_out, color_loss1, color_loss2, \
|
180 |
+
restor_loss, psnr, ssim = model(data, target, training=True)
|
181 |
+
loss = color_loss1 + color_loss2 + restor_loss
|
182 |
+
loss.backward()
|
183 |
+
optimizer.step()
|
184 |
+
|
185 |
+
epoch_loss.append(float(loss))
|
186 |
+
|
187 |
+
progress_bar.set_description(
|
188 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. color_loss1: {:1.5f}, color_loss2: {:1.5f}, restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
189 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
|
190 |
+
color_loss1.item(), color_loss2.item(),
|
191 |
+
restor_loss.item(), psnr, ssim))
|
192 |
+
writer.add_scalar('Loss/train', loss, step)
|
193 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
194 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
195 |
+
|
196 |
+
# log learning_rate
|
197 |
+
current_lr = optimizer.param_groups[0]['lr']
|
198 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
199 |
+
|
200 |
+
step += 1
|
201 |
+
|
202 |
+
except Exception as e:
|
203 |
+
print('[Error]', traceback.format_exc())
|
204 |
+
print(e)
|
205 |
+
continue
|
206 |
+
# scheduler.step(np.mean(epoch_loss))
|
207 |
+
|
208 |
+
if opt.no_sche:
|
209 |
+
scheduler.step()
|
210 |
+
|
211 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
212 |
+
|
213 |
+
if epoch % opt.val_interval == 0:
|
214 |
+
model.model_canet.eval()
|
215 |
+
loss_ls = []
|
216 |
+
psnrs = []
|
217 |
+
ssims = []
|
218 |
+
|
219 |
+
for iter, (data, target, name1, name2) in enumerate(val_generator):
|
220 |
+
with torch.no_grad():
|
221 |
+
if opt.num_gpus == 1:
|
222 |
+
data, target = data.cuda(), target.cuda()
|
223 |
+
|
224 |
+
image_out, color_loss1, color_loss2, restor_loss, \
|
225 |
+
psnr, ssim = model(data, target, training=False)
|
226 |
+
saver.save_image(data, name=os.path.splitext(name1[0])[0] + '_im1')
|
227 |
+
saver.save_image(target, name=os.path.splitext(name2[0])[0] + '_im2')
|
228 |
+
saver.save_image(image_out, name=os.path.splitext(name2[0])[0] + '_im2_pred')
|
229 |
+
|
230 |
+
loss = restor_loss + color_loss1 + color_loss2
|
231 |
+
loss_ls.append(loss.item())
|
232 |
+
psnrs.append(psnr)
|
233 |
+
ssims.append(ssim)
|
234 |
+
|
235 |
+
loss = np.mean(np.array(loss_ls))
|
236 |
+
psnr = np.mean(np.array(psnrs))
|
237 |
+
ssim = np.mean(np.array(ssims))
|
238 |
+
|
239 |
+
print(
|
240 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
241 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
242 |
+
writer.add_scalar('Loss/val', loss, step)
|
243 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
244 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
245 |
+
|
246 |
+
save_checkpoint(model, f'{opt.model}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
247 |
+
|
248 |
+
model.model_canet.train()
|
249 |
+
|
250 |
+
opt.test_on_start = False
|
251 |
+
if opt.sampling:
|
252 |
+
exit(0)
|
253 |
+
except KeyboardInterrupt:
|
254 |
+
save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
|
255 |
+
writer.close()
|
256 |
+
writer.close()
|
257 |
+
|
258 |
+
|
259 |
+
def save_checkpoint(model, name):
|
260 |
+
if isinstance(model, nn.DataParallel):
|
261 |
+
torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
262 |
+
else:
|
263 |
+
torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
|
264 |
+
|
265 |
+
|
266 |
+
if __name__ == '__main__':
|
267 |
+
opt = get_args()
|
268 |
+
train(opt)
|
train_NFM.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
import kornia
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from tqdm.autonotebook import tqdm
|
13 |
+
|
14 |
+
import models
|
15 |
+
from datasets import LowLightDataset, LowLightFDataset
|
16 |
+
from models import PSNR, SSIM, CosineLR
|
17 |
+
from tools import SingleSummaryWriter
|
18 |
+
from tools import saver, mutils
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser('Breaking Downing the Darkness')
|
23 |
+
parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
|
24 |
+
parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
|
25 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The number of images per batch among all devices')
|
26 |
+
parser.add_argument('-m1', '--model1', type=str, default='INet',
|
27 |
+
help='Model1 Name')
|
28 |
+
parser.add_argument('-m2', '--model2', type=str, default='NSNet',
|
29 |
+
help='Model2 Name')
|
30 |
+
parser.add_argument('-m3', '--model3', type=str, default='NSNet',
|
31 |
+
help='Model3 Name')
|
32 |
+
|
33 |
+
parser.add_argument('-m1w', '--model1_weight', type=str, default=None,
|
34 |
+
help='Model Name')
|
35 |
+
parser.add_argument('-m2w', '--model2_weight', type=str, default=None,
|
36 |
+
help='Model Name')
|
37 |
+
|
38 |
+
parser.add_argument('--comment', type=str, default='default',
|
39 |
+
help='Project comment')
|
40 |
+
parser.add_argument('--graph', action='store_true')
|
41 |
+
parser.add_argument('--no_sche', action='store_true')
|
42 |
+
parser.add_argument('--sampling', action='store_true')
|
43 |
+
|
44 |
+
parser.add_argument('--slope', type=float, default=2.)
|
45 |
+
parser.add_argument('--lr', type=float, default=1e-4)
|
46 |
+
parser.add_argument('--optim', type=str, default='adam', help='select optimizer for training, '
|
47 |
+
'suggest using \'admaw\' until the'
|
48 |
+
' very final stage then switch to \'sgd\'')
|
49 |
+
parser.add_argument('--num_epochs', type=int, default=500)
|
50 |
+
parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
|
51 |
+
parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
|
52 |
+
parser.add_argument('--data_path', type=str, default='./data/LOL',
|
53 |
+
help='the root folder of dataset')
|
54 |
+
parser.add_argument('--log_path', type=str, default='logs/')
|
55 |
+
parser.add_argument('--saved_path', type=str, default='logs/')
|
56 |
+
args = parser.parse_args()
|
57 |
+
return args
|
58 |
+
|
59 |
+
|
60 |
+
class ModelNSNet(nn.Module):
|
61 |
+
def __init__(self, model1, model2, model3):
|
62 |
+
super().__init__()
|
63 |
+
self.texture_loss = models.SSIML1Loss(channels=1)
|
64 |
+
self.model_ianet = model1(in_channels=1, out_channels=1)
|
65 |
+
self.model_nsnet = model2(in_channels=2, out_channels=1)
|
66 |
+
self.model_fusenet = model3(in_channels=3, out_channels=1)
|
67 |
+
|
68 |
+
assert opt.model1_weight is not None
|
69 |
+
self.load_weight(self.model_ianet, opt.model1_weight)
|
70 |
+
self.load_weight(self.model_nsnet, opt.model2_weight)
|
71 |
+
self.model_ianet.eval()
|
72 |
+
self.model_nsnet.eval()
|
73 |
+
self.eps = 1e-2
|
74 |
+
|
75 |
+
def load_weight(self, model, weight_pth):
|
76 |
+
state_dict = torch.load(weight_pth)
|
77 |
+
ret = model.load_state_dict(state_dict, strict=True)
|
78 |
+
print(ret)
|
79 |
+
|
80 |
+
def noise_syn(self, illumi, strength):
|
81 |
+
return torch.exp(-illumi) * strength
|
82 |
+
|
83 |
+
def forward(self, image, image_gt, training=True):
|
84 |
+
texture_nss = []
|
85 |
+
with torch.no_grad():
|
86 |
+
if training:
|
87 |
+
image = image.squeeze(0)
|
88 |
+
image_gt = image_gt.repeat(8, 1, 1, 1)
|
89 |
+
|
90 |
+
texture_in, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
|
91 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)
|
92 |
+
|
93 |
+
texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
|
94 |
+
illumi = self.model_ianet(texture_in_down)
|
95 |
+
illumi = F.interpolate(illumi, scale_factor=2, mode='bicubic', align_corners=True)
|
96 |
+
noisy_gt = texture_in / torch.clamp_min(illumi, self.eps)
|
97 |
+
|
98 |
+
for strength in [0, 0.05, 0.1]:
|
99 |
+
illumi = torch.clamp(illumi, 0., 1.)
|
100 |
+
attention = self.noise_syn(illumi, strength=strength)
|
101 |
+
texture_res = self.model_nsnet(torch.cat([noisy_gt, attention], dim=1))
|
102 |
+
texture_ns = noisy_gt + texture_res
|
103 |
+
texture_nss.append(texture_ns)
|
104 |
+
|
105 |
+
texture_nss = torch.cat(texture_nss, dim=1).detach()
|
106 |
+
|
107 |
+
texture_fuse = self.model_fusenet(texture_nss)
|
108 |
+
restor_loss = self.texture_loss(texture_fuse, texture_gt)
|
109 |
+
psnr = PSNR(texture_fuse, texture_gt)
|
110 |
+
ssim = SSIM(texture_fuse, texture_gt).item()
|
111 |
+
return noisy_gt, texture_nss, texture_fuse, texture_res, illumi, restor_loss, psnr, ssim
|
112 |
+
|
113 |
+
|
114 |
+
def train(opt):
|
115 |
+
if torch.cuda.is_available():
|
116 |
+
torch.cuda.manual_seed(42)
|
117 |
+
else:
|
118 |
+
torch.manual_seed(42)
|
119 |
+
|
120 |
+
timestamp = mutils.get_formatted_time()
|
121 |
+
opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
|
122 |
+
opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
|
123 |
+
os.makedirs(opt.log_path, exist_ok=True)
|
124 |
+
os.makedirs(opt.saved_path, exist_ok=True)
|
125 |
+
|
126 |
+
training_params = {'batch_size': opt.batch_size,
|
127 |
+
'shuffle': True,
|
128 |
+
'drop_last': True,
|
129 |
+
'num_workers': opt.num_workers}
|
130 |
+
|
131 |
+
val_params = {'batch_size': 1,
|
132 |
+
'shuffle': False,
|
133 |
+
'drop_last': True,
|
134 |
+
'num_workers': opt.num_workers}
|
135 |
+
|
136 |
+
training_set = LowLightFDataset(os.path.join(opt.data_path, 'train'), image_split='images_aug')
|
137 |
+
training_generator = DataLoader(training_set, **training_params)
|
138 |
+
|
139 |
+
val_set = LowLightDataset(os.path.join(opt.data_path, 'eval'))
|
140 |
+
val_generator = DataLoader(val_set, **val_params)
|
141 |
+
|
142 |
+
model1 = getattr(models, opt.model1)
|
143 |
+
model2 = getattr(models, opt.model2)
|
144 |
+
model3 = getattr(models, opt.model3)
|
145 |
+
writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
|
146 |
+
|
147 |
+
model = ModelNSNet(model1, model2, model3)
|
148 |
+
print(model)
|
149 |
+
|
150 |
+
if opt.num_gpus > 0:
|
151 |
+
model = model.cuda()
|
152 |
+
if opt.num_gpus > 1:
|
153 |
+
model = nn.DataParallel(model)
|
154 |
+
|
155 |
+
if opt.optim == 'adam':
|
156 |
+
optimizer = torch.optim.Adam(model.model_fusenet.parameters(), opt.lr)
|
157 |
+
else:
|
158 |
+
optimizer = torch.optim.SGD(model.model_fusenet.parameters(), opt.lr, momentum=0.9, nesterov=True)
|
159 |
+
|
160 |
+
scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
|
161 |
+
epoch = 0
|
162 |
+
step = 0
|
163 |
+
model.model_fusenet.train()
|
164 |
+
|
165 |
+
num_iter_per_epoch = len(training_generator)
|
166 |
+
|
167 |
+
try:
|
168 |
+
for epoch in range(opt.num_epochs):
|
169 |
+
last_epoch = step // num_iter_per_epoch
|
170 |
+
if epoch < last_epoch:
|
171 |
+
continue
|
172 |
+
|
173 |
+
epoch_loss = []
|
174 |
+
progress_bar = tqdm(training_generator)
|
175 |
+
|
176 |
+
saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)
|
177 |
+
if not opt.sampling:
|
178 |
+
for iter, (data, target, name) in enumerate(progress_bar):
|
179 |
+
if iter < step - last_epoch * num_iter_per_epoch:
|
180 |
+
progress_bar.update()
|
181 |
+
continue
|
182 |
+
try:
|
183 |
+
if opt.num_gpus == 1:
|
184 |
+
data = data.cuda()
|
185 |
+
target = target.cuda()
|
186 |
+
|
187 |
+
optimizer.zero_grad()
|
188 |
+
|
189 |
+
noisy_gt, texture_nss, texture_fuse, texture_res, \
|
190 |
+
illumi, restor_loss, psnr, ssim = model(data, target, training=True)
|
191 |
+
|
192 |
+
loss = restor_loss
|
193 |
+
loss.backward()
|
194 |
+
optimizer.step()
|
195 |
+
|
196 |
+
epoch_loss.append(float(loss))
|
197 |
+
|
198 |
+
progress_bar.set_description(
|
199 |
+
'Step: {}. Epoch: {}/{}. Iteration: {}/{}. restor_loss: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
200 |
+
step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, restor_loss.item(), psnr,
|
201 |
+
ssim))
|
202 |
+
writer.add_scalar('Loss/train', loss, step)
|
203 |
+
writer.add_scalar('PSNR/train', psnr, step)
|
204 |
+
writer.add_scalar('SSIM/train', ssim, step)
|
205 |
+
|
206 |
+
# log learning_rate
|
207 |
+
current_lr = optimizer.param_groups[0]['lr']
|
208 |
+
writer.add_scalar('learning_rate', current_lr, step)
|
209 |
+
|
210 |
+
step += 1
|
211 |
+
|
212 |
+
except Exception as e:
|
213 |
+
print('[Error]', traceback.format_exc())
|
214 |
+
print(e)
|
215 |
+
continue
|
216 |
+
|
217 |
+
if not opt.no_sche:
|
218 |
+
scheduler.step()
|
219 |
+
|
220 |
+
if epoch % opt.val_interval == 0:
|
221 |
+
model.model_fusenet.eval()
|
222 |
+
loss_ls = []
|
223 |
+
psnrs = []
|
224 |
+
ssims = []
|
225 |
+
|
226 |
+
for iter, (data, target, name) in enumerate(val_generator):
|
227 |
+
with torch.no_grad():
|
228 |
+
if opt.num_gpus == 1:
|
229 |
+
data = data.cuda()
|
230 |
+
target = target.cuda()
|
231 |
+
|
232 |
+
noisy_gt, texture_nss, texture_fuse, texture_res, \
|
233 |
+
illumi, restor_loss, psnr, ssim = model(data, target, training=False)
|
234 |
+
texture_gt, _, _ = torch.split(kornia.color.rgb_to_ycbcr(target), 1, dim=1)
|
235 |
+
|
236 |
+
saver.save_image(noisy_gt, name=os.path.splitext(name[0])[0] + '_in')
|
237 |
+
saver.save_image(texture_nss.transpose(0, 1), name=os.path.splitext(name[0])[0] + '_ns')
|
238 |
+
saver.save_image(texture_fuse, name=os.path.splitext(name[0])[0] + '_fuse')
|
239 |
+
saver.save_image(texture_res, name=os.path.splitext(name[0])[0] + '_res')
|
240 |
+
saver.save_image(illumi, name=os.path.splitext(name[0])[0] + '_ill')
|
241 |
+
saver.save_image(target, name=os.path.splitext(name[0])[0] + '_gt')
|
242 |
+
|
243 |
+
loss = restor_loss
|
244 |
+
loss_ls.append(loss.item())
|
245 |
+
psnrs.append(psnr)
|
246 |
+
ssims.append(ssim)
|
247 |
+
|
248 |
+
loss = np.mean(np.array(loss_ls))
|
249 |
+
psnr = np.mean(np.array(psnrs))
|
250 |
+
ssim = np.mean(np.array(ssims))
|
251 |
+
|
252 |
+
print(
|
253 |
+
'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
|
254 |
+
epoch, opt.num_epochs, loss, psnr, ssim))
|
255 |
+
writer.add_scalar('Loss/val', loss, step)
|
256 |
+
writer.add_scalar('PSNR/val', psnr, step)
|
257 |
+
writer.add_scalar('SSIM/val', ssim, step)
|
258 |
+
|
259 |
+
save_checkpoint(model, f'{opt.model3}_{"%03d" % epoch}_{psnr}_{ssim}_{step}.pth')
|
260 |
+
|
261 |
+
model.model_fusenet.train()
|
262 |
+
if opt.sampling:
|
263 |
+
exit(0)
|
264 |
+
except KeyboardInterrupt:
|
265 |
+
save_checkpoint(model, f'{opt.model3}_{epoch}_{step}_keyboardInterrupt.pth')
|
266 |
+
writer.close()
|
267 |
+
writer.close()
|
268 |
+
|
269 |
+
|
270 |
+
def save_checkpoint(model, name):
|
271 |
+
if isinstance(model, nn.DataParallel):
|
272 |
+
torch.save(model.module.model_fusenet.state_dict(), os.path.join(opt.saved_path, name))
|
273 |
+
else:
|
274 |
+
torch.save(model.model_fdnet.state_dict(), os.path.join(opt.saved_path, name))
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == '__main__':
|
278 |
+
opt = get_args()
|
279 |
+
train(opt)
|