vict0rsch commited on
Commit
ce190ee
1 Parent(s): c9f93f9

initial commit from `vict0rsch/climateGAN`

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +143 -0
  3. Contributing.md +56 -0
  4. LICENSE +674 -0
  5. README.md +187 -6
  6. USAGE.md +328 -0
  7. app.py +70 -0
  8. apply_events.py +642 -0
  9. climategan/__init__.py +9 -0
  10. climategan/blocks.py +398 -0
  11. climategan/bn_fusion.py +137 -0
  12. climategan/data.py +539 -0
  13. climategan/deeplab/__init__.py +101 -0
  14. climategan/deeplab/deeplab_v2.py +198 -0
  15. climategan/deeplab/deeplab_v3.py +271 -0
  16. climategan/deeplab/mobilenet_v3.py +324 -0
  17. climategan/deeplab/resnet101_v3.py +203 -0
  18. climategan/deeplab/resnetmulti_v2.py +136 -0
  19. climategan/depth.py +230 -0
  20. climategan/discriminator.py +361 -0
  21. climategan/eval_metrics.py +635 -0
  22. climategan/fid.py +561 -0
  23. climategan/fire.py +133 -0
  24. climategan/generator.py +415 -0
  25. climategan/logger.py +445 -0
  26. climategan/losses.py +620 -0
  27. climategan/masker.py +234 -0
  28. climategan/norms.py +186 -0
  29. climategan/optim.py +291 -0
  30. climategan/painter.py +171 -0
  31. climategan/strings.py +99 -0
  32. climategan/trainer.py +1939 -0
  33. climategan/transforms.py +626 -0
  34. climategan/tutils.py +721 -0
  35. climategan/utils.py +1063 -0
  36. config/model/masker/.ipynb_checkpoints/opts-checkpoint.yaml +3 -0
  37. config/model/masker/opts.yaml +3 -0
  38. config/model/painter/opts.yaml +3 -0
  39. eval_masker.py +796 -0
  40. figures/ablation_comparison.py +394 -0
  41. figures/bootstrap_ablation.py +562 -0
  42. figures/bootstrap_ablation_summary.py +361 -0
  43. figures/human_evaluation.py +208 -0
  44. figures/labels.py +200 -0
  45. figures/metrics.py +676 -0
  46. figures/metrics_onefig.py +772 -0
  47. inferences.py +108 -0
  48. requirements-3.8.2.txt +91 -0
  49. requirements-any.txt +20 -0
  50. sbatch.py +933 -0
.gitattributes CHANGED
@@ -31,3 +31,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ config/model/** filter=lfs diff=lfs merge=lfs -text
35
+ notebooks/** filter=lfs diff=lfs merge=lfs -text
36
+ images/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ omnienv/
2
+ example_data/
3
+ .vscode/
4
+ .comet.config
5
+ .DS_Store
6
+ config/
7
+ tests/not_committed/
8
+ *.hydra
9
+ outputs/
10
+ eval_folder*
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ pip-wheel-metadata/
35
+ share/python-wheels/
36
+ *.egg-info/
37
+ .installed.cfg
38
+ *.egg
39
+ MANIFEST
40
+
41
+ # PyInstaller
42
+ # Usually these files are written by a python script from a template
43
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
44
+ *.manifest
45
+ *.spec
46
+
47
+ # Installer logs
48
+ pip-log.txt
49
+ pip-delete-this-directory.txt
50
+
51
+ # Unit test / coverage reports
52
+ htmlcov/
53
+ .tox/
54
+ .nox/
55
+ .coverage
56
+ .coverage.*
57
+ .cache
58
+ nosetests.xml
59
+ coverage.xml
60
+ *.cover
61
+ *.py,cover
62
+ .hypothesis/
63
+ .pytest_cache/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ # local visualize tool
143
+ visualizedEval.py
Contributing.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1. Understand the file structure:
2
+ 1. architectures in `discriminator.py` `generator.py` `classifier.py`
3
+ 2. data loading in `data.py`
4
+ 3. data transformation `transforms.py`
5
+ 4. optimizers in `optim.py`
6
+ 5. utilities in `utils.py`
7
+ 6. training procedure in `trainer.py`
8
+ 2. Write **tests** in `tests/`
9
+ 1. your file should match `test_*.py`
10
+ 2. update existing tests when adding functionalities
11
+ 3. run tests regularly to check you haven't broken anything `python tests/run.py`
12
+ 3. Add **WIP** in your PR's title when not ready to merge
13
+ 4. Open an Issue if something's odd, or to assign yourself a todo
14
+ 5. **Format your code** with [black](https://github.com/psf/black)
15
+ 6. Only update `trainer/defaults.yaml` with values that should be shared across runs and users
16
+ 1. use `config/trainer/local_tests.yaml` or any other to set up your particular config overriding `trainer/defaults.yaml`
17
+
18
+ ## Running tests
19
+
20
+ As per `6.` you should set your particular config in `config/local_tests.yaml`. Mine looks like:
21
+
22
+ ```yaml
23
+ output_path: /Users/victor/Documents/ccai/github/climategan/example_data
24
+ # -------------------
25
+ # ----- Tasks -----
26
+ # -------------------
27
+ #tasks: [a, d, h, s, t, w]
28
+ tasks: [a, d, s, t] # for now no h or w
29
+ # ----------------
30
+ # ----- Data -----
31
+ # ----------------
32
+ data:
33
+ files: # if one is not none it will override the dirs location
34
+ base: /Users/victor/Documents/ccai/github/climategan/example_data
35
+ transforms:
36
+ - name: hflip
37
+ ignore: false
38
+ p: 0.5
39
+ - name: resize
40
+ ignore: false
41
+ new_size: 256
42
+ - name: crop
43
+ ignore: false
44
+ height: 64
45
+ width: 64
46
+ gen:
47
+ encoder:
48
+ n_res: 1
49
+ default:
50
+ n_res: 1
51
+
52
+ train:
53
+ log_level: 1
54
+ ```
55
+
56
+ Setting `n_res` to 1 is important to run tests faster and with less memory
LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md CHANGED
@@ -1,13 +1,194 @@
1
  ---
 
 
 
 
 
 
 
2
  title: ClimateGAN
3
- emoji: 🏢
4
- colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.6
8
  app_file: app.py
9
- pinned: false
10
- license: gpl-3.0
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - Climate Change
6
+ - GAN
7
+ - Domain Adaptation
8
+ license: gpl-3.0
9
  title: ClimateGAN
10
+ emoji: 🌎
11
+ colorFrom: blue
12
  colorTo: green
13
  sdk: gradio
14
+ sdk_version: 4.6
15
  app_file: app.py
16
+ inference: true
17
+ # datasets:
18
+ # -
19
  ---
20
 
21
+ # ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
22
+
23
+ This repository contains the code used to train the model presented in our **[paper](https://openreview.net/forum?id=EZNOb_uNpJk)**.
24
+
25
+ It is not simply a presentation repository but the code we have used over the past 30 months to come to our final architecture. As such, you will find many scripts, classes, blocks and options which we actively use for our own development purposes but are not directly relevant to reproduce results or use pretrained weights.
26
+
27
+ ![flood processing](images/flood.png)
28
+
29
+ If you use this code, data or pre-trained weights, please cite our ICLR 2022 paper:
30
+
31
+ ```
32
+ @inproceedings{schmidt2022climategan,
33
+ title = {Climate{GAN}: Raising Climate Change Awareness by Generating Images of Floods},
34
+ author = {Victor Schmidt and Alexandra Luccioni and M{\'e}lisande Teng and Tianyu Zhang and Alexia Reynaud and Sunand Raghupathi and Gautier Cosne and Adrien Juraver and Vahe Vardanyan and Alex Hern{\'a}ndez-Garc{\'\i}a and Yoshua Bengio},
35
+ booktitle = {International Conference on Learning Representations},
36
+ year = {2022},
37
+ url = {https://openreview.net/forum?id=EZNOb_uNpJk}
38
+ }
39
+ ```
40
+
41
+ ## Using pre-trained weights
42
+
43
+ In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
44
+
45
+ * reusing the segmentation map, we are able to isolate the sky, turn it red and in a few more steps create an image resembling the consequences of a wildfire on a neighboring area, similarly to the [California wildfires](https://www.google.com/search?q=california+wildfires+red+sky&source=lnms&tbm=isch&sa=X&ved=2ahUKEwisws-hx7zxAhXxyYUKHQyKBUwQ_AUoAXoECAEQBA&biw=1680&bih=917&dpr=2).
46
+ * reusing the depth map, we can simulate the consequences of a smog event on an image, scaling the intensity of the filter by the distance of an object to the camera, as per [HazeRD](http://www2.ece.rochester.edu/~gsharma/papers/Zhang_ICIP2017_HazeRD.pdf)
47
+
48
+ ![image of wildfire processing](images/wildfire.png)
49
+ ![image of smog processing](images/smog.png)
50
+
51
+ In this section we'll explain how to produce the `Painted Input` along with the Smog and Wildfire outputs of a pre-trained ClimateGAN model.
52
+
53
+ ### Installation
54
+
55
+ This repository and associated model have been developed using Python 3.8.2 and **Pytorch 1.7.0**.
56
+
57
+ ```bash
58
+ $ git clone git@github.com:cc-ai/climategan.git
59
+ $ cd climategan
60
+ $ pip install -r requirements-3.8.2.txt # or `requirements-any.txt` for other Python versions (not tested but expected to be fine)
61
+ ```
62
+
63
+ Our pipeline uses [comet.ml](https://comet.ml) to log images. You don't *have* to use their services but we recommend you do as images can be uploaded to your workspace instead of being written to disk.
64
+
65
+ If you want to use Comet, make sure you have the [appropriate configuration in place (API key and workspace at least)](https://www.comet.ml/docs/python-sdk/advanced/#non-interactive-setup)
66
+
67
+ ### Inference
68
+
69
+ 1. Download and unzip the weights [from this link](https://drive.google.com/u/0/uc?id=18OCUIy7JQ2Ow_-cC5xn_hhDn-Bp45N1K&export=download) (checkout [`gdown`](https://github.com/wkentaro/gdown) for a commandline interface) and put them in `config/`
70
+
71
+ ```
72
+ $ pip install gdown
73
+ $ mkdir config
74
+ $ cd config
75
+ $ gdown https://drive.google.com/u/0/uc?id=18OCUIy7JQ2Ow_-cC5xn_hhDn-Bp45N1K
76
+ $ unzip release-github-v1.zip
77
+ $ cd ..
78
+ ```
79
+
80
+ 2. Run from the repo's root:
81
+
82
+ 1. With `comet`:
83
+
84
+ ```bash
85
+ python apply_events.py --batch_size 4 --half --images_paths path/to/a/folder --resume_path config/model/masker --upload
86
+ ```
87
+
88
+ 2. Without `comet` (and shortened args compared to the previous example):
89
+
90
+ ```bash
91
+ python apply_events.py -b 4 --half -i path/to/a/folder -r config/model/masker --output_path path/to/a/folder
92
+ ```
93
+
94
+ The `apply_events.py` script has many options, for instance to use a different output size than the default systematic `640 x 640` pixels, look at the code or `python apply_events.py --help`.
95
+
96
+ ## Training from scratch
97
+
98
+ ClimateGAN is split in two main components: the Masker producing a binary mask of where water should go and the Painter generating water within this mask given an initial image's context.
99
+
100
+ ### Configuration
101
+
102
+ The code is structured to use `shared/trainer/defaults.yaml` as default configuration. There are 2 ways of overriding those for your purposes (without altering that file):
103
+
104
+ 1. By providing an alternative configuration as command line argument `config=path/to/config.yaml`
105
+
106
+ 1. The code will first load `shared/trainer/defaults.yaml`
107
+ 2. *then* update the resulting dictionary with values read in the provided `config` argument.
108
+ 3. The folder `config/` is NOT tracked by git so you would typically put them there
109
+
110
+ 2. By overwriting specific arguments from the command-line like `python train.py data.loaders.batch_size=8`
111
+
112
+
113
+ ### Data
114
+
115
+ #### Masker
116
+
117
+ ##### Real Images
118
+
119
+ Because of copyright issues we are not able to share the real images scraped from the internet. You would have to do that yourself. In the `yaml` config file, the code expects a key pointing to a `json` file like `data.files.<train or val>.r: <path/to/a/json/file>`. This `json` file should be a list of dictionaries with tasks as keys and files as values. Example:
120
+
121
+ ```json
122
+ [
123
+ {
124
+ "x": "path/to/a/real/image",
125
+ "s": "path/to/a/segmentation_map",
126
+ "d": "path/to/a/depth_map"
127
+ },
128
+ ...
129
+ ]
130
+ ```
131
+
132
+ Following the [ADVENT](https://github.com/valeoai/ADVENT) procedure, only `x` should be required. We use `s` and `d` inferred from pre-trained models (DeepLab v3+ and MiDaS) to use those pseudo-labels in the first epochs of training (see `pseudo:` in the config file)
133
+
134
+ ##### Simulated Images
135
+
136
+ We share snapshots of the Virtual World we created in the [Mila-Simulated-Flood dataset](). You can download and unzip one water-level and then produce json files similar to that of the real data, with an additional key `"m": "path/to/a/ground_truth_sim_mask"`. Lastly, edit the config file: `data.files.<train or val>.s: <path/to/a/json/file>`
137
+
138
+ #### Painter
139
+
140
+ The painter expects input images and binary masks to train using the [GauGAN](https://github.com/NVlabs/SPADE) training procedure. Unfortunately we cannot share openly the collected data, but similarly as for the Masker's real data you would point to the data using a `json` file as:
141
+
142
+ ```json
143
+ [
144
+ {
145
+ "x": "path/to/a/real/image",
146
+ "m": "path/to/a/water_mask",
147
+ },
148
+ ...
149
+ ]
150
+ ```
151
+
152
+ And put those files as values to `data.files.<train or val>.rf: <path/to/a/json/file>` in the configuration.
153
+
154
+ ## Coding conventions
155
+
156
+ * Tasks
157
+ * `x` is an input image, in [-1, 1]
158
+ * `s` is a segmentation target with `long` classes
159
+ * `d` is a depth map target in R, may be actually `log(depth)` or `1/depth`
160
+ * `m` is a binary mask with 1s where water is/should be
161
+ * Domains
162
+ * `r` is the *real* domain for the masker. Input images are real pictures of urban/suburban/rural areas
163
+ * `s` is the *simulated* domain for the masker. Input images are taken from our Unity world
164
+ * `rf` is the *real flooded* domain for the painter. Training images are pairs `(x, m)` of flooded scenes for which the water should be reconstructed, in the validation data input images are not flooded and we provide a manually labeled mask `m`
165
+ * `kitti` is a special `s` domain to pre-train the masker on [Virtual Kitti 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
166
+ * it alters the `trainer.loaders` dict to select relevant data sources from `trainer.all_loaders` in `trainer.switch_data()`. The rest of the code is identical.
167
+ * Flow
168
+ * This describes the call stack for the trainers standard training procedure
169
+ * `train()`
170
+ * `run_epoch()`
171
+ * `update_G()`
172
+ * `zero_grad(G)`
173
+ * `get_G_loss()`
174
+ * `get_masker_loss()`
175
+ * `masker_m_loss()` -> masking loss
176
+ * `masker_s_loss()` -> segmentation loss
177
+ * `masker_d_loss()` -> depth estimation loss
178
+ * `get_painter_loss()` -> painter's loss
179
+ * `g_loss.backward()`
180
+ * `g_opt_step()`
181
+ * `update_D()`
182
+ * `zero_grad(D)`
183
+ * `get_D_loss()`
184
+ * painter's disc losses
185
+ * `masker_m_loss()` -> masking AdvEnt disc loss
186
+ * `masker_s_loss()` -> segmentation AdvEnt disc loss
187
+ * `d_loss.backward()`
188
+ * `d_opt_step()`
189
+ * `update_learning_rates()` -> update learning rates according to schedules defined in `opts.gen.opt` and `opts.dis.opt`
190
+ * `run_validation()`
191
+ * compute val losses
192
+ * `eval_images()` -> compute metrics
193
+ * `log_comet_images()` -> compute and upload inferences
194
+ * `save()`
USAGE.md ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ClimateGAN
2
+ - [ClimateGAN](#climategan)
3
+ - [Setup](#setup)
4
+ - [Coding conventions](#coding-conventions)
5
+ - [updates](#updates)
6
+ - [interfaces](#interfaces)
7
+ - [Logging on comet](#logging-on-comet)
8
+ - [Resources](#resources)
9
+ - [Example](#example)
10
+ - [Release process](#release-process)
11
+
12
+ ## Setup
13
+
14
+ **`PyTorch >= 1.1.0`** otherwise optimizer.step() and scheduler.step() are in the wrong order ([docs](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate))
15
+
16
+ **pytorch==1.6** to use pytorch-xla or automatic mixed precision (`amp` branch).
17
+
18
+ Configuration files use the **YAML** syntax. If you don't know what `&` and `<<` mean, you'll have a hard time reading the files. Have a look at:
19
+
20
+ * https://dev.to/paulasantamaria/introduction-to-yaml-125f
21
+ * https://stackoverflow.com/questions/41063361/what-is-the-double-left-arrow-syntax-in-yaml-called-and-wheres-it-specced/41065222
22
+
23
+ **pip**
24
+
25
+ ```
26
+ $ pip install comet_ml scipy opencv-python torch torchvision omegaconf==1.4.1 hydra-core==0.11.3 scikit-image imageio addict tqdm torch_optimizer
27
+ ```
28
+
29
+ ## Coding conventions
30
+
31
+ * Tasks
32
+ * `x` is an input image, in [-1, 1]
33
+ * `s` is a segmentation target with `long` classes
34
+ * `d` is a depth map target in R, may be actually `log(depth)` or `1/depth`
35
+ * `m` is a binary mask with 1s where water is/should be
36
+ * Domains
37
+ * `r` is the *real* domain for the masker. Input images are real pictures of urban/suburban/rural areas
38
+ * `s` is the *simulated* domain for the masker. Input images are taken from our Unity world
39
+ * `rf` is the *real flooded* domain for the painter. Training images are pairs `(x, m)` of flooded scenes for which the water should be reconstructed, in the validation data input images are not flooded and we provide a manually labeled mask `m`
40
+ * `kitti` is a special `s` domain to pre-train the masker on [Virtual Kitti 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
41
+ * it alters the `trainer.loaders` dict to select relevant data sources from `trainer.all_loaders` in `trainer.switch_data()`. The rest of the code is identical.
42
+ * Flow
43
+ * This describes the call stack for the trainers standard training procedure
44
+ * `train()`
45
+ * `run_epoch()`
46
+ * `update_G()`
47
+ * `zero_grad(G)`
48
+ * `get_G_loss()`
49
+ * `get_masker_loss()`
50
+ * `masker_m_loss()` -> masking loss
51
+ * `masker_s_loss()` -> segmentation loss
52
+ * `masker_d_loss()` -> depth estimation loss
53
+ * `get_painter_loss()` -> painter's loss
54
+ * `g_loss.backward()`
55
+ * `g_opt_step()`
56
+ * `update_D()`
57
+ * `zero_grad(D)`
58
+ * `get_D_loss()`
59
+ * painter's disc losses
60
+ * `masker_m_loss()` -> masking AdvEnt disc loss
61
+ * `masker_s_loss()` -> segmentation AdvEnt disc loss
62
+ * `d_loss.backward()`
63
+ * `d_opt_step()`
64
+ * `update_learning_rates()` -> update learning rates according to schedules defined in `opts.gen.opt` and `opts.dis.opt`
65
+ * `run_validation()`
66
+ * compute val losses
67
+ * `eval_images()` -> compute metrics
68
+ * `log_comet_images()` -> compute and upload inferences
69
+ * `save()`
70
+
71
+ ### Resuming
72
+
73
+ Set `train.resume` to `True` in `opts.yaml` and specify where to load the weights:
74
+
75
+ Use a config's `load_path` namespace. It should have sub-keys `m`, `p` and `pm`:
76
+
77
+ ```yaml
78
+ load_paths:
79
+ p: none # Painter weights
80
+ m: none # Masker weights
81
+ pm: none # Painter + Masker weights (single ckpt for both)
82
+ ```
83
+
84
+ 1. any path which leads to a dir will be loaded as `path / checkpoints / latest_ckpt.pth`
85
+ 2. if you want to specify a specific checkpoint (not the latest), it MUST be a `.pth` file
86
+ 3. resuming a `P` **OR** an `M` model, you may only specify 1 of `load_path.p` **OR** `load_path.m`.
87
+ You may also leave **BOTH** at `none`, in which case `output_path / checkpoints / latest_ckpt.pth`
88
+ will be used
89
+ 4. resuming a P+M model, you may specify (`p` AND `m`) **OR** `pm` **OR** leave all at `none`,
90
+ in which case `output_path / checkpoints / latest_ckpt.pth` will be used to load from
91
+ a single checkpoint
92
+
93
+ ### Generator
94
+
95
+ * **Encoder**:
96
+
97
+ `trainer.G.encoder` Deeplabv2 or v3-based encoder
98
+ * Code borrowed from
99
+ * https://github.com/valeoai/ADVENT/blob/master/advent/model/deeplabv2.py
100
+ * https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes
101
+
102
+ * **Decoders**:
103
+ * `trainer.G.decoders["s"]` -> *Segmentation* -> DLV3+ architecture (ASPP + Decoder)
104
+ * `trainer.G.decoders["d"]` -> *Depth* -> ResBlocks + (Upsample + Conv)
105
+ * `trainer.G.decoders["m"]` -> *Mask* -> ResBlocks + (Upsample + Conv) -> Binary mask: 1 = water should be there
106
+ * `trainer.G.mask()` predicts a mask and optionally applies `sigmoid` from an `x` input or a `z` input
107
+
108
+ * **Painter**: `trainer.G.painter` -> [GauGAN SPADE-based](https://github.com/NVlabs/SPADE)
109
+ * input = masked image
110
+ * `trainer.G.paint(m, x)` higher level function which takes care of masking
111
+ * If `opts.gen.p.paste_original_content` the painter should only create water and not reconstruct outside the mask: the output of `paint()` is `painted * m + x * (1 - m)`
112
+
113
+ High level methods of interest:
114
+
115
+ * `trainer.infer_all()` creates a dictionary of events with keys `flood` `wildfire` and `smog`. Can take in a single image or a batch, of numpy arrays or torch tensors, on CPU/GPU/TPU. This method calls, amongst others:
116
+ * `trainer.G.encode()` to compute the shared latent vector `z`
117
+ * `trainer.G.mask(z=z)` to infer the mask
118
+ * `trainer.compute_fire(x, segmentation)` to create a wildfire image from `x` and inferred segmentation
119
+ * `trainer.compute_smog(x, depth)` to create a smog image from `x` and inferred depth
120
+ * `trainer.compute_flood(x, mask)` to create a flood image from `x` and inferred mask using the painter (`trainer.G.paint(m, x)`)
121
+ * `Trainer.resume_from_path()` static method to resume a trainer from a path
122
+
123
+ ### Discriminator
124
+
125
+ ## updates
126
+
127
+ multi-batch:
128
+
129
+ ```
130
+ multi_domain_batch = {"rf: batch0, "r": batch1, "s": batch2}
131
+ ```
132
+
133
+ ## interfaces
134
+
135
+ ### batches
136
+ ```python
137
+ batch = Dict({
138
+ "data": {
139
+ "d": depthmap,,
140
+ "s": segmentation_map,
141
+ "m": binary_mask
142
+ "x": real_flooded_image,
143
+ },
144
+ "paths":{
145
+ same_keys: path_to_file
146
+ }
147
+ "domain": list(rf | r | s),
148
+ "mode": list(train | val)
149
+ })
150
+ ```
151
+
152
+ ### data
153
+
154
+ #### json files
155
+
156
+ | name | domain | description | author |
157
+ | :--------------------------------------------- | :----: | :------------------------------------------------------------------------- | :-------: |
158
+ | **train_r_full.json, val_r_full.json** | r | MiDaS+ Segmentation pseudo-labels .pt (HRNet + Cityscapes) | Mélisande |
159
+ | **train_s_full.json, val_s_full.json** | s | Simulated data from Unity11k urban + Unity suburban dataset | *** |
160
+ | train_s_nofences.json, val_s_nofences.json | s | Simulated data from Unity11k urban + Unity suburban dataset without fences | Alexia |
161
+ | train_r_full_pl.json, val_r_full_pl.json | r | MegaDepth + Segmentation pseudo-labels .pt (HRNet + Cityscapes) | Alexia |
162
+ | train_r_full_midas.json, val_r_full_midas.json | r | MiDaS+ Segmentation (HRNet + Cityscapes) | Mélisande |
163
+ | train_r_full_old.json, val_r_full_old.json | r | MegaDepth+ Segmentation (HRNet + Cityscapes) | *** |
164
+ | train_r_nopeople.json, val_r_nopeople.json | r | Same training data as above with people removed | Sasha |
165
+ | train_rf_with_sim.json | rf | Doubled train_rf's size with sim data (randomly chosen) | Victor |
166
+ | train_rf.json | rf | UPDATE (12/12/20): added 50 ims & masks from ADE20K Outdoors | Victor |
167
+ | train_allres.json, val_allres.json | rf | includes both lowres and highres from ORCA_water_seg | Tianyu |
168
+ | train_highres_only.json, val_highres_only.json | rf | includes only highres from ORCA_water_seg | Tianyu |
169
+
170
+
171
+ ```yaml
172
+ # data file ; one for each r|s
173
+ - x: /path/to/image
174
+ m: /path/to/mask
175
+ s: /path/to/segmentation map
176
+ - x: /path/to/another image
177
+ d: /path/to/depth map
178
+ m: /path/to/mask
179
+ s: /path/to/segmentation map
180
+ - x: ...
181
+ ```
182
+
183
+ or
184
+
185
+ ```json
186
+ [
187
+ {
188
+ "x": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005.jpg",
189
+ "s": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005.npy",
190
+ "d": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005_depth.jpg"
191
+ },
192
+ {
193
+ "x": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006.jpg",
194
+ "s": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006.npy",
195
+ "d": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006_depth.jpg"
196
+ }
197
+ ]
198
+ ```
199
+
200
+ The json files used are located at `/network/tmp1/ccai/data/climategan/`. In the basenames, `_s` denotes simulated domain data and `_r` real domain data.
201
+ The `base` folder contains json files with paths to images (`"x"` key) and masks (taken as ground truth for the area that should be flooded, `"m"` key).
202
+ The `seg` folder contains json files and keys `"x"`, `"m"` and `"s"` (segmentation) for each image.
203
+
204
+
205
+ loaders
206
+
207
+ ```
208
+ loaders = Dict({
209
+ train: { r: loader, s: loader},
210
+ val: { r: loader, s: loader}
211
+ })
212
+ ```
213
+
214
+ ### losses
215
+
216
+ `trainer.losses` is a dictionary mapping to loss functions to optimize for the 3 main parts of the architecture: generator `G`, discriminators `D`:
217
+
218
+ ```python
219
+ trainer.losses = {
220
+ "G":{ # generator
221
+ "gan": { # gan loss from the discriminators
222
+ "a": GANLoss, # adaptation decoder
223
+ "t": GANLoss # translation decoder
224
+ },
225
+ "cycle": { # cycle-consistency loss
226
+ "a": l1 | l2,,
227
+ "t": l1 | l2,
228
+ },
229
+ "auto": { # auto-encoding loss a.k.a. reconstruction loss
230
+ "a": l1 | l2,
231
+ "t": l1 | l2
232
+ },
233
+ "tasks": { # specific losses for each auxillary task
234
+ "d": func, # depth estimation
235
+ "h": func, # height estimation
236
+ "s": cross_entropy_2d, # segmentation
237
+ "w": func, # water generation
238
+ },
239
+ "classifier": l1 | l2 | CE # loss from fooling the classifier
240
+ },
241
+ "D": GANLoss, # discriminator losses from the generator and true data
242
+ "C": l1 | l2 | CE # classifier should predict the right 1-h vector [rf, rn, sf, sn]
243
+ }
244
+ ```
245
+
246
+ ## Logging on comet
247
+
248
+ Comet.ml will look for api keys in the following order: argument to the `Experiment(api_key=...)` call, `COMET_API_KEY` environment variable, `.comet.config` file in the current working directory, `.comet.config` in the current user's home directory.
249
+
250
+ If you're not managing several comet accounts at the same time, I recommend putting `.comet.config` in your home as such:
251
+
252
+ ```
253
+ [comet]
254
+ api_key=<api_key>
255
+ workspace=vict0rsch
256
+ rest_api_key=<rest_api_key>
257
+ ```
258
+
259
+ ### Tests
260
+
261
+ Run tests by executing `python test_trainer.py`. You can add `--no_delete` not to delete the comet experiment at exit and inspect uploads.
262
+
263
+ Write tests as scenarios by adding to the list `test_scenarios` in the file. A scenario is a dict of overrides over the base opts in `shared/trainer/defaults.yaml`. You can create special flags for the scenario by adding keys which start with `__`. For instance, `__doc` is a mandatory key in any scenario describing it succinctly.
264
+
265
+ ## Resources
266
+
267
+ [Tricks and Tips for Training a GAN](https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/)
268
+ [GAN Hacks](https://github.com/soumith/ganhacks)
269
+ [Keep Calm and train a GAN. Pitfalls and Tips on training Generative Adversarial Networks](https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9)
270
+
271
+ ## Example
272
+
273
+ **Inference: computing floods**
274
+
275
+ ```python
276
+ from pathlib import Path
277
+ from skimage.io import imsave
278
+ from tqdm import tqdm
279
+
280
+ from climategan.trainer import Trainer
281
+ from climategan.utils import find_images
282
+ from climategan.tutils import tensor_ims_to_np_uint8s
283
+ from climategan.transforms import PrepareInference
284
+
285
+
286
+ model_path = "some/path/to/output/folder" # not .ckpt
287
+ input_folder = "path/to/a/folder/with/images"
288
+ output_path = "path/where/images/will/be/written"
289
+
290
+ # resume trainer
291
+ trainer = Trainer.resume_from_path(model_path, new_exp=None, inference=True)
292
+
293
+ # find paths for all images in the input folder. There is a recursive option.
294
+ im_paths = sorted(find_images(input_folder), key=lambda x: x.name)
295
+
296
+ # Load images into tensors
297
+ # * smaller side resized to 640 - keeping aspect ratio
298
+ # * then longer side is cropped in the center
299
+ # * result is a 1x3x640x640 float tensor in [-1; 1]
300
+ xs = PrepareInference()(im_paths)
301
+
302
+ # send to device
303
+ xs = [x.to(trainer.device) for x in xs]
304
+
305
+ # compute flood
306
+ # * compute mask
307
+ # * binarize mask if bin_value > 0
308
+ # * paint x using this mask
309
+ ys = [trainer.compute_flood(x, bin_value=0.5) for x in tqdm(xs)]
310
+
311
+ # convert 1x3x640x640 float tensors in [-1; 1] into 640x640x3 numpy arrays in [0, 255]
312
+ np_ys = [tensor_ims_to_np_uint8s(y) for y in tqdm(ys)]
313
+
314
+ # write images
315
+ for i, n in tqdm(zip(im_paths, np_ys), total=len(im_paths)):
316
+ imsave(Path(output_path) / i.name, n)
317
+ ```
318
+
319
+ ## Release process
320
+
321
+ In the `release/` folder
322
+ * create a `model/` folder
323
+ * create folders `model/masker/` and `model/painter/`
324
+ * add the climategan code in `release/`: `git clone git@github.com:cc-ai/climategan.git`
325
+ * move the code to `release/`: `cp climategan/* . && rm -rf climategan`
326
+ * update `model/masker/opts/events` with `events:` from `shared/trainer/opts.yaml`
327
+ * update `model/masker/opts/val.val_painter` to `"model/painter/checkpoints/latest_ckpt.pth"`
328
+ * update `model/masker/opts/load_paths.m` to `"model/masker/checkpoints/latest_ckpt.pth"`
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import os
5
+ import gradio as gr
6
+ import googlemaps
7
+ from skimage import io
8
+ from urllib import parse
9
+ from inferences import ClimateGAN
10
+
11
+
12
def predict(api_key):
    """
    Build the Gradio prediction callback.

    Args:
        api_key (str | None): Google Maps API key. When set and an address is
            provided, a Street View image is fetched for that address instead
            of using the uploaded image.

    Returns:
        callable: function mapping the Gradio inputs to
            (original, flood, wildfire, smog) numpy images.
    """

    def _predict(*args):
        print("args: ", args)
        image = place = None
        # BUGFIX: the original unpacked `image, place = args`, but the
        # gr.Interface below passes inputs as [Textbox, Image], so the
        # address string arrived in `image` and the array in `place`.
        # Identify each argument by type instead, robust to input order.
        for arg in args:
            if isinstance(arg, str):
                place = arg
            elif arg is not None:
                image = arg

        if api_key and place:
            # Geocode the place name, then fetch an outdoor Street View image
            geocode_result = gmaps.geocode(place)

            address = geocode_result[0]["formatted_address"]
            static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
            img_np = io.imread(static_map_url)
        else:
            img_np = image
        flood, wildfire, smog = model.inference(img_np)
        return img_np, flood, wildfire, smog

    return _predict
34
+
35
+
36
+ if __name__ == "__main__":
37
+
38
+ api_key = os.environ.get("GMAPS_API_KEY")
39
+ gmaps = None
40
+ if api_key is not None:
41
+ gmaps = googlemaps.Client(key=api_key)
42
+
43
+ model = ClimateGAN(model_path="config/model/masker")
44
+
45
+ inputs = inputs = [gr.inputs.Image(label="Input Image")]
46
+ if api_key:
47
+ inputs += [gr.inputs.Textbox(label="Address or place name")]
48
+
49
+ gr.Interface(
50
+ predict(api_key),
51
+ inputs=[
52
+ gr.inputs.Textbox(label="Address or place name"),
53
+ gr.inputs.Image(label="Input Image"),
54
+ ],
55
+ outputs=[
56
+ gr.outputs.Image(type="numpy", label="Original image"),
57
+ gr.outputs.Image(type="numpy", label="Flooding"),
58
+ gr.outputs.Image(type="numpy", label="Wildfire"),
59
+ gr.outputs.Image(type="numpy", label="Smog"),
60
+ ],
61
+ title="ClimateGAN: Visualize Climate Change",
62
+ description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at <a href="https://thisclimatedoesnotexist.com/">ThisClimateDoesNotExist.com</a>.<br>Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.', # noqa: E501
63
+ article="<p style='text-align: center'>This project is an unofficial clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>", # noqa: E501
64
+ # examples=[
65
+ # "Vancouver Art Gallery",
66
+ # "Chicago Bean",
67
+ # "Duomo Siracusa",
68
+ # ],
69
+ css=".footer{display:none !important}",
70
+ ).launch()
apply_events.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
def parse_args():
    """
    Parse the command-line arguments of apply_events.py.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=4,
        help="Batch size to process input images to events. Defaults to 4",
    )
    parser.add_argument(
        "-i",
        "--images_paths",
        type=str,
        required=True,
        help="Path to a directory with image files",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default=None,
        # BUGFIX: "were" -> "where" in user-facing help text
        help="Path to a directory where events should be written. "
        + "Will NOT write anything to disk if this flag is not used.",
    )
    parser.add_argument(
        "-s",
        "--save_input",
        action="store_true",
        default=False,
        help="Binary flag to include the input image to the model (after crop and"
        + " resize) in the images written or uploaded (depending on saving options.)",
    )
    parser.add_argument(
        "-r",
        "--resume_path",
        type=str,
        default=None,
        # BUGFIX: "opts.yam" -> "opts.yaml" (actual file name the trainer reads)
        help="Path to a directory containing the trainer to resume."
        + " In particular it must contain `opts.yaml` and `checkpoints/`."
        + " Typically this points to a Masker, which holds the path to a"
        + " Painter in its opts",
    )
    parser.add_argument(
        "--no_time",
        action="store_true",
        default=False,
        help="Binary flag to prevent the timing of operations.",
    )
    parser.add_argument(
        "-f",
        "--flood_mask_binarization",
        type=float,
        default=0.5,
        help="Value to use to binarize masks (mask > value). "
        + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
    )
    parser.add_argument(
        "-t",
        "--target_size",
        type=int,
        default=640,
        help="Output image size (when not using `keep_ratio_128`): images are resized"
        + " such that their smallest side is `target_size` then cropped in the middle"
        + " of the largest side such that the resulting input image (and output images)"
        + " has height and width `target_size x target_size`. **Must** be a multiple of"
        + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Binary flag to use half precision (float16). Defaults to False.",
    )
    parser.add_argument(
        "-n",
        "--n_images",
        default=-1,
        type=int,
        help="Limit the number of images processed (if you have 100 images in "
        + "a directory but n is 10 then only the first 10 images will be loaded"
        + " for processing)",
    )
    parser.add_argument(
        "--no_conf",
        action="store_true",
        default=False,
        help="disable writing the apply_events hash and command in the output folder",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Do not check for existing outdir, i.e. force overwrite"
        + " potentially existing files in the output path",
    )
    parser.add_argument(
        "--no_cloudy",
        action="store_true",
        default=False,
        help="Prevent the use of the cloudy intermediate"
        + " image to create the flood image. Rendering will"
        + " be more colorful but may seem less realistic",
    )
    parser.add_argument(
        "--keep_ratio_128",
        action="store_true",
        default=False,
        help="When loading the input images, resize and crop them in order for their "
        + "dimensions to match the closest multiples"
        + " of 128. Will force a batch size of 1 since images"
        + " now have different dimensions. "
        + "Use --max_im_width to cap the resulting dimensions.",
    )
    parser.add_argument(
        "--fuse",
        action="store_true",
        default=False,
        help="Use batch norm fusion to speed up inference",
    )
    parser.add_argument(
        "--save_masks",
        action="store_true",
        default=False,
        help="Save output masks along events",
    )
    parser.add_argument(
        "-m",
        "--max_im_width",
        type=int,
        default=-1,
        help="When using --keep_ratio_128, some images may still be too large. Use "
        + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to comet.ml in a project called `climategan-apply`",
    )
    parser.add_argument(
        "--zip_outdir",
        "-z",
        action="store_true",
        help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
    )
    return parser.parse_args()
149
+
150
+
151
+ args = parse_args()
152
+
153
+
154
+ print("\n• Imports\n")
155
+ import time
156
+
157
+ import_time = time.time()
158
+ import sys
159
+ import shutil
160
+ from collections import OrderedDict
161
+ from pathlib import Path
162
+
163
+ import comet_ml # noqa: F401
164
+ import torch
165
+ import numpy as np
166
+ import skimage.io as io
167
+ from skimage.color import rgba2rgb
168
+ from skimage.transform import resize
169
+ from tqdm import tqdm
170
+
171
+ from climategan.trainer import Trainer
172
+ from climategan.bn_fusion import bn_fuse
173
+ from climategan.tutils import print_num_parameters
174
+ from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve
175
+
176
+ import_time = time.time() - import_time
177
+
178
+
179
def to_m1_p1(img, i):
    """
    Rescale a [0, 1] image to [-1, +1].

    Args:
        img (np.array): float32 numpy array of an image in [0, 1]
        i (int): index of the image being rescaled (used in the error message)

    Raises:
        ValueError: if the image's values are not all within [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    # guard clause: reject out-of-range data with the offending image's index
    if img.min() < 0 or img.max() > 1:
        raise ValueError(
            f"Data range mismatch for image {i} : ({img.min()}, {img.max()})"
        )
    return (img.astype(np.float32) - 0.5) * 2
196
+
197
+
198
def uint8(array):
    """
    Cast an array to np.uint8 (dtype change only: no rescaling or clipping).

    Args:
        array (np.array): array to cast

    Returns:
        np.array(np.uint8): converted array
    """
    converted = array.astype(np.uint8)
    return converted
209
+
210
+
211
def resize_and_crop(img, to=640):
    """
    Resize an image so that its smallest side equals `to` (keeping the aspect
    ratio), then center-crop the longest side so that the output is `to x to`
    without aspect-ratio distortion.

    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): output side length. Defaults to 640.

    Returns:
        np.array: [0, 1] np.float32 image
    """
    height, width = img.shape[:2]
    # smallest side becomes `to`, the other one is scaled to keep the ratio
    if height < width:
        new_size = (to, int(to * width / height))
    else:
        new_size = (int(to * height / width), to)

    resized = uint8(resize(img, new_size, preserve_range=True, anti_aliasing=True))

    # center-crop the longest side to obtain a square `to x to` image
    new_h, new_w = resized.shape[:2]
    top = (new_h - to) // 2
    left = (new_w - to) // 2
    cropped = resized[top : top + to, left : left + to, :]

    # back to [0, 1] floats
    return cropped / 255.0
242
+
243
+
244
def print_time(text, time_series, purge=-1):
    """
    Print a time series' mean and standard deviation with a label.

    Args:
        text (str): label of the time series
        time_series (list): list of timings
        purge (int, optional): ignore first n values of time series. Defaults to -1.
    """
    # nothing to report for an empty series
    if not time_series:
        return

    values = time_series
    if purge > 0 and len(values) > purge:
        values = values[purge:]

    mean = np.mean(values)
    std = np.std(values)

    # only show a +/- spread when there is more than one measurement
    suffix = f" +/- {std:.5f}" if len(values) > 1 else ""
    print(f"{text.capitalize() + ' ':.<26} {mean:.5f}" + suffix)
266
+
267
+
268
def print_store(store, purge=-1):
    """
    Pretty-print a store of time series.

    Args:
        store (dict): maps string keys to lists of times
        purge (int, optional): ignore first n values of time series. Defaults to -1.
    """
    # partition series by length: one-off timings, repeated timings, empties
    singles = OrderedDict((k, v) for k, v in store.items() if len(v) == 1)
    multiples = OrderedDict((k, v) for k, v in store.items() if len(v) > 1)
    empties = {k: v for k, v in store.items() if not v}

    if empties:
        print("Ignoring empty stores ", ", ".join(empties.keys()))
        print()

    for key, series in singles.items():
        print_time(key, series, purge)

    print()
    print("Unit: s/batch")
    for key, series in multiples.items():
        print_time(key, series, purge)
    print()
292
+
293
+
294
def write_apply_config(out):
    """
    Save the current `apply_events.py` command and git hash to text files
    in `out` for future reference.
    """
    working_dir = Path.cwd().expanduser().resolve()
    # record the exact shell command that produced this output folder
    command = f"cd {str(working_dir)}\n" + " ".join(sys.argv)
    (out / "command.txt").write_text(command)
    (out / "hash.txt").write_text(get_git_revision_hash())
306
+
307
+
308
def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
    """
    Build the output directory's name from the user-provided arguments.
    """
    parts = []
    if half:
        parts.append("half")
    if keep_ratio:
        parts.append("AR")
        # the width cap is only meaningful in keep-ratio mode
        if max_im_width:
            parts.append(f"{max_im_width}")
    elif target_size:
        parts.extend(["S", f"{target_size}"])
    # only mention the binarization value when it differs from the default
    if bin_value != 0.5:
        parts.append(f"bin{bin_value}")
    if not cloudy:
        parts.append("no_cloudy")

    return "-".join(parts)
328
+
329
+
330
def make_outdir(
    outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
):
    """
    Create the output directory if it does not exist. If it does exist,
    ask the user for confirmation (unless `overwrite` is True).
    If the output directory's name is "_auto_" it is created as:
    outdir.parent / get_outdir_name(...)
    """
    if outdir.name == "_auto_":
        auto_name = get_outdir_name(
            half, keep_ratio, max_im_width, target_size, bin_value, cloudy
        )
        outdir = outdir.parent / auto_name
    if outdir.exists() and not overwrite:
        print(
            f"\nWARNING: outdir ({str(outdir)}) already exists."
            + " Files with existing names will be overwritten"
        )
        answer = input(">>> Continue anyway? [y / n] (default: y) : ")
        if "n" in answer:
            print("Interrupting execution from user input.")
            sys.exit()
        print()
    outdir.mkdir(exist_ok=True, parents=True)
    return outdir
354
+
355
+
356
def get_time_stores(import_time):
    """
    Initialize the ordered mapping of timing series, seeded with the
    already-measured import time.
    """
    stores = OrderedDict()
    stores["imports"] = [import_time]
    # all other series start empty and are filled during inference
    for key in (
        "setup",
        "data pre-processing",
        "encode",
        "mask",
        "flood",
        "depth",
        "segmentation",
        "smog",
        "wildfire",
        "all events",
        "numpy",
        "inference on all images",
        "write",
    ):
        stores[key] = []
    return stores
375
+
376
+
377
+ if __name__ == "__main__":
378
+
379
+ # -----------------------------------------
380
+ # ----- Initialize script variables -----
381
+ # -----------------------------------------
382
+ print(
383
+ "• Using args\n\n"
384
+ + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
385
+ )
386
+
387
+ batch_size = args.batch_size
388
+ bin_value = args.flood_mask_binarization
389
+ cloudy = not args.no_cloudy
390
+ fuse = args.fuse
391
+ half = args.half
392
+ save_masks = args.save_masks
393
+ images_paths = resolve(args.images_paths)
394
+ keep_ratio = args.keep_ratio_128
395
+ max_im_width = args.max_im_width
396
+ n_images = args.n_images
397
+ outdir = resolve(args.output_path) if args.output_path is not None else None
398
+ resume_path = args.resume_path
399
+ target_size = args.target_size
400
+ time_inference = not args.no_time
401
+ upload = args.upload
402
+ zip_outdir = args.zip_outdir
403
+
404
+ # -------------------------------------
405
+ # ----- Validate size arguments -----
406
+ # -------------------------------------
407
+ if keep_ratio:
408
+ if target_size != 640:
409
+ print(
410
+ "\nWARNING: using --keep_ratio_128 overwrites target_size"
411
+ + " which is ignored."
412
+ )
413
+ if batch_size != 1:
414
+ print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
415
+ batch_size = 1
416
+ if max_im_width > 0 and max_im_width % 128 != 0:
417
+ new_im_width = int(max_im_width / 128) * 128
418
+ print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
419
+ print(
420
+ " Was {} but is now overwritten to {}".format(
421
+ max_im_width, new_im_width
422
+ )
423
+ )
424
+ max_im_width = new_im_width
425
+ else:
426
+ if target_size % 128 != 0:
427
+ print(f"\nWarning: target size {target_size} is not a multiple of 128.")
428
+ target_size = target_size - (target_size % 128)
429
+ print(f"Setting target_size to {target_size}.")
430
+
431
+ # -------------------------------------
432
+ # ----- Create output directory -----
433
+ # -------------------------------------
434
+ if outdir is not None:
435
+ outdir = make_outdir(
436
+ outdir,
437
+ args.overwrite,
438
+ half,
439
+ keep_ratio,
440
+ max_im_width,
441
+ target_size,
442
+ bin_value,
443
+ cloudy,
444
+ )
445
+
446
+ # -------------------------------
447
+ # ----- Create time store -----
448
+ # -------------------------------
449
+ stores = get_time_stores(import_time)
450
+
451
+ # -----------------------------------
452
+ # ----- Load Trainer instance -----
453
+ # -----------------------------------
454
+ with Timer(store=stores.get("setup", []), ignore=time_inference):
455
+ print("\n• Initializing trainer\n")
456
+ torch.set_grad_enabled(False)
457
+ trainer = Trainer.resume_from_path(
458
+ resume_path,
459
+ setup=True,
460
+ inference=True,
461
+ new_exp=None,
462
+ )
463
+ print()
464
+ print_num_parameters(trainer, True)
465
+ if fuse:
466
+ trainer.G = bn_fuse(trainer.G)
467
+ if half:
468
+ trainer.G.half()
469
+
470
+ # --------------------------------------------
471
+ # ----- Read data from input directory -----
472
+ # --------------------------------------------
473
+ print("\n• Reading & Pre-processing Data\n")
474
+
475
+ # find all images
476
+ data_paths = find_images(images_paths)
477
+ base_data_paths = data_paths
478
+ # filter images
479
+ if 0 < n_images < len(data_paths):
480
+ data_paths = data_paths[:n_images]
481
+ # repeat data
482
+ elif n_images > len(data_paths):
483
+ repeats = n_images // len(data_paths) + 1
484
+ data_paths = base_data_paths * repeats
485
+ data_paths = data_paths[:n_images]
486
+
487
+ with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
488
+ # read images to numpy arrays
489
+ data = [io.imread(str(d)) for d in data_paths]
490
+ # rgba to rgb
491
+ data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]
492
+ # resize images to target_size or
493
+ if keep_ratio:
494
+ # to closest multiples of 128 <= max_im_width, keeping aspect ratio
495
+ new_sizes = [to_128(d, max_im_width) for d in data]
496
+ data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
497
+ else:
498
+ # to args.target_size
499
+ data = [resize_and_crop(d, target_size) for d in data]
500
+ new_sizes = [(target_size, target_size) for _ in data]
501
+ # resize() produces [0, 1] images, rescale to [-1, 1]
502
+ data = [to_m1_p1(d, i) for i, d in enumerate(data)]
503
+
504
+ n_batchs = len(data) // batch_size
505
+ if len(data) % batch_size != 0:
506
+ n_batchs += 1
507
+
508
+ print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")
509
+
510
+ # --------------------------------------------
511
+ # ----- Batch-process images to events -----
512
+ # --------------------------------------------
513
+ print(f"\n• Using device {str(trainer.device)}\n")
514
+
515
+ all_events = []
516
+
517
+ with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
518
+ for b in tqdm(range(n_batchs), desc="Infering events", unit="batch"):
519
+
520
+ images = data[b * batch_size : (b + 1) * batch_size]
521
+ if not images:
522
+ continue
523
+
524
+ # concatenate images in a batch batch_size x height x width x 3
525
+ images = np.stack(images)
526
+ # Retreive numpy events as a dict {event: array[BxHxWxC]}
527
+ events = trainer.infer_all(
528
+ images,
529
+ numpy=True,
530
+ stores=stores,
531
+ bin_value=bin_value,
532
+ half=half,
533
+ cloudy=cloudy,
534
+ return_masks=save_masks,
535
+ )
536
+
537
+ # save resized and cropped image
538
+ if args.save_input:
539
+ events["input"] = uint8((images + 1) / 2 * 255)
540
+
541
+ # store events to write after inference loop
542
+ all_events.append(events)
543
+
544
+ # --------------------------------------------
545
+ # ----- Save (write/upload) inferences -----
546
+ # --------------------------------------------
547
+ if outdir is not None or upload:
548
+
549
+ if upload:
550
+ print("\n• Creating comet Experiment")
551
+ exp = comet_ml.Experiment(project_name="climategan-apply")
552
+ exp.log_parameters(vars(args))
553
+
554
+ # --------------------------------------------------------------
555
+ # ----- Change inferred data structure to a list of dicts -----
556
+ # --------------------------------------------------------------
557
+ to_write = []
558
+ events_names = list(all_events[0].keys())
559
+ for events_data in all_events:
560
+ n_ims = len(events_data[events_names[0]])
561
+ for i in range(n_ims):
562
+ item = {event: events_data[event][i] for event in events_names}
563
+ to_write.append(item)
564
+
565
+ progress_bar_desc = ""
566
+ if outdir is not None:
567
+ print("\n• Output directory:\n")
568
+ print(str(outdir), "\n")
569
+ if upload:
570
+ progress_bar_desc = "Writing & Uploading events"
571
+ else:
572
+ progress_bar_desc = "Writing events"
573
+ else:
574
+ if upload:
575
+ progress_bar_desc = "Uploading events"
576
+
577
+ # ------------------------------------
578
+ # ----- Save individual images -----
579
+ # ------------------------------------
580
+ with Timer(store=stores.get("write", []), ignore=time_inference):
581
+
582
+ # for each image
583
+ for t, event_dict in tqdm(
584
+ enumerate(to_write),
585
+ desc=progress_bar_desc,
586
+ unit="input image",
587
+ total=len(to_write),
588
+ ):
589
+
590
+ idx = t % len(base_data_paths)
591
+ stem = Path(data_paths[idx]).stem
592
+ width = new_sizes[idx][1]
593
+
594
+ if keep_ratio:
595
+ ar = "_AR"
596
+ else:
597
+ ar = ""
598
+
599
+ # for each event type
600
+ event_bar = tqdm(
601
+ enumerate(event_dict.items()),
602
+ leave=False,
603
+ total=len(events_names),
604
+ unit="event",
605
+ )
606
+ for e, (event, im_data) in event_bar:
607
+ event_bar.set_description(
608
+ f" {event.capitalize():<{len(progress_bar_desc) - 2}}"
609
+ )
610
+
611
+ if args.no_cloudy:
612
+ suffix = ar + "_no_cloudy"
613
+ else:
614
+ suffix = ar
615
+
616
+ im_path = Path(f"{stem}_{event}_{width}{suffix}.png")
617
+
618
+ if outdir is not None:
619
+ im_path = outdir / im_path
620
+ io.imsave(im_path, im_data)
621
+
622
+ if upload:
623
+ exp.log_image(im_data, name=im_path.name)
624
+ if zip_outdir:
625
+ print("\n• Zipping output directory... ", end="", flush=True)
626
+ archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
627
+ archive_path = archive_path.rename(outdir.parent / archive_path.name)
628
+ print("Done:\n")
629
+ print(str(archive_path))
630
+
631
+ # ---------------------------
632
+ # ----- Print timings -----
633
+ # ---------------------------
634
+ if time_inference:
635
+ print("\n• Timings\n")
636
+ print_store(stores)
637
+
638
+ # ---------------------------------------------
639
+ # ----- Save apply_events.py run config -----
640
+ # ---------------------------------------------
641
+ if not args.no_conf and outdir is not None:
642
+ write_apply_config(outdir)
climategan/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from importlib import import_module
from pathlib import Path

# Eagerly import every sibling module of the package (for their side effects
# / to make `climategan.<mod>` attributes available) and expose the module
# *names* through __all__.
#
# Fix: the original put the module OBJECTS returned by import_module() into
# __all__, but `from climategan import *` requires every item of __all__ to
# be a string — it raised "TypeError: Item in __all__ must be str".
__all__ = []
for _file in Path(__file__).parent.glob("*.py"):
    if "__" not in _file.stem:
        import_module(f".{_file.stem}", __package__)
        __all__.append(_file.stem)
del import_module, Path
climategan/blocks.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """File for all blocks which are parts of decoders
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import climategan.strings as strings
8
+ from climategan.norms import SPADE, AdaptiveInstanceNorm2d, LayerNorm, SpectralNorm
9
+
10
+
11
class InterpolateNearest2d(nn.Module):
    """
    Nearest-neighbor 2x (or `scale_factor`x) upsampling module.

    Re-implements nn.Upsample without `scale_factor` because pytorch/xla
    only supports an explicit `output_size`, which is computed here from
    the input's spatial dimensions at every call.
    """

    def __init__(self, scale_factor=2):
        """
        Args:
            scale_factor (int, optional): multiplier applied to the last two
                dimensions of the input. Defaults to 2.
        """
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """
        Upsample `x` in "nearest" mode on its last 2 dimensions.

        Args:
            x (torch.Tensor): tensor of shape (..., H, W)

        Returns:
            torch.Tensor: tensor of shape
                (..., H * scale_factor, W * scale_factor)
        """
        target_h = x.shape[-2] * self.scale_factor
        target_w = x.shape[-1] * self.scale_factor
        return F.interpolate(x, size=(target_h, target_w), mode="nearest")
44
+
45
+
46
+ # -----------------------------------------
47
+ # ----- Generic Convolutional Block -----
48
+ # -----------------------------------------
49
class Conv2dBlock(nn.Module):
    """
    Pad -> Conv2d -> (Norm) -> (Activation) building block.

    Padding is done by a dedicated layer (reflect / replicate / zero), so the
    Conv2d itself is created without a padding argument. A "spectral_<norm>"
    norm string applies spectral normalization to the convolution AND uses
    <norm> as the normalization layer; plain "spectral" applies spectral
    normalization only (no extra norm layer).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        norm="none",
        activation="relu",
        pad_type="zero",
        bias=True,
    ):
        super().__init__()
        self.use_bias = bias
        # initialize padding (consumes `padding`; conv gets none)
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization: strip the "spectral_" prefix and remember
        # that spectral norm must wrap the convolution below
        use_spectral_norm = False
        if norm.startswith("spectral_"):
            norm = norm.replace("spectral_", "")
            use_spectral_norm = True

        norm_dim = output_dim
        if norm == "batch":
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == "instance":
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == "layer":
            self.norm = LayerNorm(norm_dim)
        elif norm == "adain":
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == "spectral" or norm.startswith("spectral_"):
            # NOTE(review): the startswith() half of this test is unreachable —
            # the "spectral_" prefix was already stripped above. Harmless.
            self.norm = None  # dealt with later in the code
        elif norm == "none":
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == "relu":
            self.activation = nn.ReLU(inplace=False)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=False)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        # initialize convolution (spectral-norm-wrapped or plain)
        # NOTE(review): the plain path drops the conv bias when followed by
        # BatchNorm (redundant with BN's beta), but the spectral path keeps
        # it — presumably intentional, verify against training configs.
        if norm == "spectral" or use_spectral_norm:
            self.conv = SpectralNorm(
                nn.Conv2d(
                    input_dim,
                    output_dim,
                    kernel_size,
                    stride,
                    dilation=dilation,
                    bias=self.use_bias,
                )
            )
        else:
            self.conv = nn.Conv2d(
                input_dim,
                output_dim,
                kernel_size,
                stride,
                dilation=dilation,
                bias=self.use_bias if norm != "batch" else False,
            )

    def forward(self, x):
        # pad -> conv -> optional norm -> optional activation
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def __str__(self):
        return strings.conv2dblock(self)
148
+
149
+
150
+ # -----------------------------
151
+ # ----- Residual Blocks -----
152
+ # -----------------------------
153
class ResBlocks(nn.Module):
    """
    Sequential stack of `num_blocks` residual blocks sharing the same
    width, norm, activation and padding settings.

    From https://github.com/NVlabs/MUNIT/blob/master/networks.py
    """

    def __init__(self, num_blocks, dim, norm="in", activation="relu", pad_type="zero"):
        super().__init__()
        blocks = [
            ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
            for _ in range(num_blocks)
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        # Apply every residual block in order.
        return self.model(x)

    def __str__(self):
        return strings.resblocks(self)
172
+
173
+
174
class ResBlock(nn.Module):
    """
    Standard residual block: y = x + conv(act(conv(x))), built from two
    3x3 Conv2dBlocks (the second one without activation).
    """

    def __init__(self, dim, norm="in", activation="relu", pad_type="zero"):
        super().__init__()
        # kept as attributes for string representations elsewhere
        self.dim = dim
        self.norm = norm
        self.activation = activation
        self.model = nn.Sequential(
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
            ),
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
            ),
        )

    def forward(self, x):
        # Skip connection around the two convolutions.
        return x + self.model(x)

    def __str__(self):
        return strings.resblock(self)
201
+
202
+
203
+ # --------------------------
204
+ # ----- Base Decoder -----
205
+ # --------------------------
206
class BaseDecoder(nn.Module):
    """
    Generic decoder: optional 1x1 projection of the encoder features,
    optional merge with low-level (skip) features, a stack of residual
    blocks, then `n_upsample` (nearest-upsample + 3x3 conv) stages halving
    the channel count each time, and a final 3x3 output convolution.
    """

    def __init__(
        self,
        n_upsample=4,          # number of x2 upsampling stages
        n_res=4,               # number of residual blocks before upsampling
        input_dim=2048,        # encoder feature channels
        proj_dim=64,           # channels after 1x1 projection (-1 = no projection)
        output_dim=3,          # output channels
        norm="batch",
        activ="relu",
        pad_type="zero",
        output_activ="tanh",
        low_level_feats_dim=-1,  # channels of skip features (-1 = disabled)
        use_dada=False,          # multiply features by a depth tensor (DADA-style)
    ):
        super().__init__()

        self.low_level_feats_dim = low_level_feats_dim
        self.use_dada = use_dada

        self.model = []
        if proj_dim != -1:
            # 1x1 conv to reduce encoder channels before the res blocks
            self.proj_conv = Conv2dBlock(
                input_dim, proj_dim, 1, 1, 0, norm=norm, activation=activ
            )
        else:
            self.proj_conv = None
            proj_dim = input_dim

        if low_level_feats_dim > 0:
            # project skip features to proj_dim, then fuse with a 1x1 conv
            self.low_level_conv = Conv2dBlock(
                input_dim=low_level_feats_dim,
                output_dim=proj_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
            self.merge_feats_conv = Conv2dBlock(
                input_dim=2 * proj_dim,
                output_dim=proj_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
        else:
            self.low_level_conv = None

        self.model += [ResBlocks(n_res, proj_dim, norm, activ, pad_type=pad_type)]
        dim = proj_dim
        # upsampling blocks: each stage doubles H, W and halves channels
        for i in range(n_upsample):
            self.model += [
                InterpolateNearest2d(scale_factor=2),
                Conv2dBlock(
                    input_dim=dim,
                    output_dim=dim // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    pad_type=pad_type,
                    norm=norm,
                    activation=activ,
                ),
            ]
            dim //= 2
        # use reflection padding in the last conv layer
        self.model += [
            Conv2dBlock(
                input_dim=dim,
                output_dim=output_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm="none",
                activation=output_activ,
            )
        ]
        self.model = nn.Sequential(*self.model)

    def forward(self, z, cond=None, z_depth=None):
        """
        Decode features `z` into an image-like tensor.

        Args:
            z: encoder features, or a (features, low_level_features) pair
            cond: unused here (kept for interface compatibility with
                conditional decoders)
            z_depth: optional depth features multiplied into z when
                use_dada is True
        """
        low_level_feat = None
        if isinstance(z, (list, tuple)):
            if self.low_level_conv is None:
                # skip features provided but this decoder ignores them
                z = z[0]
            else:
                z, low_level_feat = z
                low_level_feat = self.low_level_conv(low_level_feat)
                low_level_feat = F.interpolate(
                    low_level_feat, size=z.shape[-2:], mode="bilinear"
                )

        if z_depth is not None and self.use_dada:
            # DADA-style depth-aware feature modulation
            z = z * z_depth

        if self.proj_conv is not None:
            z = self.proj_conv(z)

        if low_level_feat is not None:
            # channel-concat then 1x1 fuse; order is (skip, main)
            z = self.merge_feats_conv(torch.cat([low_level_feat, z], dim=1))

        return self.model(z)

    def __str__(self):
        return strings.basedecoder(self)
317
+
318
+
319
+ # --------------------------
320
+ # ----- SPADE Blocks -----
321
+ # --------------------------
322
+ # https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
323
+ # /models/networks/generator.py
324
+ # 0ff661e on 13 Apr 2019
325
class SPADEResnetBlock(nn.Module):
    """
    Residual block whose normalizations are SPADE layers conditioned on a
    segmentation map `seg`; optionally spectral-normalized convolutions.

    From
    https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
    /models/networks/generator.py (0ff661e on 13 Apr 2019)
    """

    def __init__(
        self,
        fin,                      # input channels
        fout,                     # output channels
        cond_nc,                  # channels of the conditioning map
        spade_use_spectral_norm,
        spade_param_free_norm,    # base norm inside SPADE (e.g. instance)
        spade_kernel_size,
        last_activation=None,     # None or "lrelu" applied to the sum
    ):
        super().__init__()
        # Attributes

        self.fin = fin
        self.fout = fout
        self.use_spectral_norm = spade_use_spectral_norm
        self.param_free_norm = spade_param_free_norm
        self.kernel_size = spade_kernel_size

        # a 1x1 shortcut conv is only needed when channel counts differ
        self.learned_shortcut = fin != fout
        self.last_activation = last_activation
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if spade_use_spectral_norm:
            self.conv_0 = SpectralNorm(self.conv_0)
            self.conv_1 = SpectralNorm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = SpectralNorm(self.conv_s)

        # one SPADE per convolution input
        self.norm_0 = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
        self.norm_1 = SPADE(spade_param_free_norm, spade_kernel_size, fmiddle, cond_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        # norm -> leaky-relu -> conv, twice
        dx = self.conv_0(self.activation(self.norm_0(x, seg)))
        dx = self.conv_1(self.activation(self.norm_1(dx, seg)))

        out = x_s + dx
        if self.last_activation == "lrelu":
            return self.activation(out)
        elif self.last_activation is None:
            return out
        else:
            raise NotImplementedError(
                "The type of activation is not supported: {}".format(
                    self.last_activation
                )
            )

    def shortcut(self, x, seg):
        # identity when fin == fout, otherwise SPADE-normalized 1x1 conv
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def activation(self, x):
        # LeakyReLU with slope 0.2 (not a module so it carries no state)
        return F.leaky_relu(x, 2e-1)

    def __str__(self):
        return strings.spaderesblock(self)
climategan/bn_fusion.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from copy import deepcopy
3
+
4
+
5
class FlattableModel(object):
    """
    Wraps a model with a flattened (depth-first) view of its leaf modules so
    layers can be inspected and swapped in place (used by `bn_fuse`).

    The wrapped model is a deepcopy: mutations through this wrapper never
    touch the instance passed to the constructor.
    """

    def __init__(self, model):
        self.model = deepcopy(model)
        self._original_model = model   # untouched reference to the input model
        self._flat_model = None        # lazy cache for flatten_model()
        self._attr_names = self.get_attributes_name()  # attribute path per leaf

    def flatten_model(self):
        # Lazily compute and cache the flat leaf-module list.
        if self._flat_model is None:
            self._flat_model = self._flatten_model(self.model)
        return self._flat_model

    @staticmethod
    def _selection_method(module):
        # A module counts as a leaf unless it is a pure container
        # (Sequential / ModuleList) or explicitly marked "_restricted".
        return not (
            isinstance(module, torch.nn.Sequential)
            or isinstance(module, torch.nn.ModuleList)
        ) and not hasattr(module, "_restricted")

    @staticmethod
    def _flatten_model(module):
        # Depth-first collection of leaf modules (modules with no children
        # that pass _selection_method).
        modules = []
        child = False
        for (name, c) in module.named_children():
            child = True
            flattened_c = FlattableModel._flatten_model(c)
            modules += flattened_c
        if not child and FlattableModel._selection_method(module):
            modules = [module]
        return modules

    def get_layer_io(self, layer, nb_samples, data_loader):
        """Record (input, output) pairs of `layer` over the first
        `nb_samples` batches of `data_loader` (-1 = all batches).

        NOTE(review): assumes a CUDA device is available (`xs.cuda()`), and
        `nb_samples` counts *batches*, not samples — confirm with callers.
        """
        ios = []
        hook = layer.register_forward_hook(
            lambda m, i, o: ios.append((i[0].data.cpu(), o.data.cpu()))
        )

        nbatch = 1
        for batch_idx, (xs, ys) in enumerate(data_loader):
            # -1 takes all of them
            if nb_samples != -1 and nbatch > nb_samples:
                break
            _ = self.model(xs.cuda())
            nbatch += 1

        hook.remove()
        return ios

    def get_attributes_name(self):
        # For every leaf module, record the chain of attribute names leading
        # to it from the model root (e.g. ["layer1", "0", "conv"]), so leaves
        # can later be replaced with setattr (see update_model).
        def _real_get_attributes_name(module):
            modules = []
            child = False
            for (name, c) in module.named_children():
                child = True
                flattened_c = _real_get_attributes_name(c)
                modules += map(lambda e: [name] + e, flattened_c)
            if not child and FlattableModel._selection_method(module):
                modules = [[]]
            return modules

        return _real_get_attributes_name(self.model)

    def update_model(self, flat_model):
        """
        Take a list representing the flatten model and rebuild its internals.
        :type flat_model: List[nn.Module]
        """

        def _apply_changes_on_layer(block, idxs, layer):
            # Walk the attribute path and set the new layer at its end.
            assert len(idxs) > 0
            if len(idxs) == 1:
                setattr(block, idxs[0], layer)
            else:
                _apply_changes_on_layer(getattr(block, idxs[0]), idxs[1:], layer)

        def _apply_changes_model(model_list):
            for i in range(len(model_list)):
                _apply_changes_on_layer(self.model, self._attr_names[i], model_list[i])

        _apply_changes_model(flat_model)
        # Attribute paths and the cached flat view are stale after mutation.
        self._attr_names = self.get_attributes_name()
        self._flat_model = None

    def cuda(self):
        self.model = self.model.cuda()
        return self

    def cpu(self):
        self.model = self.model.cpu()
        return self
95
+
96
+
97
def bn_fuse(model):
    """
    Fold every BatchNorm2d that directly follows a Conv2d (in the flattened
    leaf order) into that convolution's weight and bias, replacing the
    BatchNorm with an identity layer.

    Args:
        model (torch.nn.Module): model to fuse (moved to CPU)

    Returns:
        torch.nn.Module: a fused deepcopy of the model
    """
    model = model.cpu()
    flattable = FlattableModel(model)
    layers = flattable.flatten_model()

    for i, layer in enumerate(layers[:-1]):
        follower = layers[i + 1]
        if not (
            isinstance(layer, torch.nn.Conv2d)
            and isinstance(follower, torch.nn.BatchNorm2d)
        ):
            continue
        alpha, beta = _calculate_alpha_beta(follower)
        if layer.weight.shape[0] != alpha.shape[0]:
            # Channel mismatch: something actually sits between the conv and
            # the bn that the flattening logic did not pick up (see densenet)
            # — leave this pair untouched.
            continue
        layer.weight.data = layer.weight.data * alpha.view(-1, 1, 1, 1)
        layer.bias = torch.nn.Parameter(beta)
        layers[i + 1] = _IdentityLayer()

    flattable.update_model(layers)
    return flattable.model
119
+
120
+
121
+ def _calculate_alpha_beta(batchnorm_layer):
122
+ alpha = batchnorm_layer.weight.data / (
123
+ torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps)
124
+ )
125
+ beta = (
126
+ -(batchnorm_layer.weight.data * batchnorm_layer.running_mean)
127
+ / (torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps))
128
+ + batchnorm_layer.bias.data
129
+ )
130
+ alpha = alpha.cpu()
131
+ beta = beta.cpu()
132
+ return alpha, beta
133
+
134
+
135
+ class _IdentityLayer(torch.nn.Module):
136
+ def forward(self, input):
137
+ return input
climategan/data.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data-loading functions in order to create a Dataset and DataLoaders.
2
+ Transforms for loaders are in transforms.py
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ import yaml
12
+ from imageio import imread
13
+ from PIL import Image
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torchvision import transforms
16
+
17
+ from climategan.transforms import get_transforms
18
+ from climategan.tutils import get_normalized_depth_t
19
+ from climategan.utils import env_to_path, is_image_file
20
+
21
# Color palettes per domain: maps integer class label -> RGBA ("s", "r")
# or RGB ("kitti", "flood") color used in the corresponding segmentation maps.
classes_dict = {
    "s": {  # unity
        0: [0, 0, 255, 255],  # Water
        1: [55, 55, 55, 255],  # Ground
        2: [0, 255, 255, 255],  # Building
        3: [255, 212, 0, 255],  # Traffic items
        4: [0, 255, 0, 255],  # Vegetation
        5: [255, 97, 0, 255],  # Terrain
        6: [255, 0, 0, 255],  # Car
        7: [60, 180, 60, 255],  # Trees
        8: [255, 0, 255, 255],  # Person
        9: [0, 0, 0, 255],  # Sky
        10: [255, 255, 255, 255],  # Default
    },
    "r": {  # deeplab v2
        0: [0, 0, 255, 255],  # Water
        1: [55, 55, 55, 255],  # Ground
        2: [0, 255, 255, 255],  # Building
        3: [255, 212, 0, 255],  # Traffic items
        4: [0, 255, 0, 255],  # Vegetation
        5: [255, 97, 0, 255],  # Terrain
        6: [255, 0, 0, 255],  # Car
        7: [60, 180, 60, 255],  # Trees
        8: [220, 20, 60, 255],  # Person
        9: [8, 19, 49, 255],  # Sky
        10: [0, 80, 100, 255],  # Default
    },
    "kitti": {
        0: [210, 0, 200],  # Terrain
        1: [90, 200, 255],  # Sky
        2: [0, 199, 0],  # Tree
        3: [90, 240, 0],  # Vegetation
        4: [140, 140, 140],  # Building
        5: [100, 60, 100],  # Road
        6: [250, 100, 255],  # GuardRail
        7: [255, 255, 0],  # TrafficSign
        8: [200, 200, 0],  # TrafficLight
        9: [255, 130, 0],  # Pole
        10: [80, 80, 80],  # Misc
        11: [160, 60, 60],  # Truck
        12: [255, 127, 80],  # Car
        13: [0, 139, 139],  # Van
        14: [0, 0, 0],  # Undefined
    },
    "flood": {
        0: [255, 0, 0],  # Cannot flood
        1: [0, 0, 255],  # Must flood
        2: [0, 0, 0],  # May flood
    },
}

# Label remapping kitti label -> climategan ("s"/"r") label; used by
# merge_labels() when loading kitti segmentation maps.
kitti_mapping = {
    0: 5,  # Terrain -> Terrain
    1: 9,  # Sky -> Sky
    2: 7,  # Tree -> Trees
    3: 4,  # Vegetation -> Vegetation
    4: 2,  # Building -> Building
    5: 1,  # Road -> Ground
    6: 3,  # GuardRail -> Traffic items
    7: 3,  # TrafficSign -> Traffic items
    8: 3,  # TrafficLight -> Traffic items
    9: 3,  # Pole -> Traffic items
    10: 10,  # Misc -> default
    11: 6,  # Truck -> Car
    12: 6,  # Car -> Car
    13: 6,  # Van -> Car
    14: 10,  # Undefined -> Default
}
89
+
90
+
91
def encode_exact_segmap(seg, classes_dict, default_value=14):
    """
    When the mapping (rgb -> label) is known to be exact (no approximative
    rgb values), maps an rgb image to segmap labels. Pixels matching no
    entry of `classes_dict` get `default_value`.

    Args:
        seg (np.ndarray): H x W x 3 RGB image
        classes_dict (dict): Mapping {class: rgb value}
        default_value (int, optional): Value for unknown label. Defaults to 14.

    Returns:
        np.ndarray: H x W float array of labels (not RGB)
    """
    labels = np.full(seg.shape[:2], default_value, dtype=np.float64)
    for class_index, rgb in classes_dict.items():
        labels[(seg == rgb).all(-1)] = class_index
    return labels
108
+
109
+
110
def merge_labels(labels, mapping, default_value=14):
    """
    Maps labels from a source domain to labels of a target domain,
    typically kitti -> climategan. Source labels absent from `mapping`
    become `default_value`.

    Args:
        labels (np.ndarray): input segmap labels
        mapping (dict): source_label -> target_label
        default_value (int, optional): Unknown label. Defaults to 14.

    Returns:
        np.ndarray: adapted labels, same shape and dtype as `labels`
    """
    merged = np.full_like(labels, default_value)
    for source_label, target_label in mapping.items():
        merged[labels == source_label] = target_label
    return merged
127
+
128
+
129
def process_kitti_seg(path, kitti_classes, merge_map, default=14):
    """
    Processes a path to produce a 1 x 1 x H x W torch segmap:
    read RGB -> exact-match kitti labels -> remap to climategan labels.

    %timeit process_kitti_seg(path, classes_dict, mapping, default=14)
    326 ms ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    Args:
        path (str | pathlib.Path): Segmap RBG path
        kitti_classes (dict): Kitti map label -> rgb
        merge_map (dict): map kitti_label -> climategan_label
        default (int, optional): Unknown kitti label. Defaults to 14.

    Returns:
        torch.Tensor: 1 x 1 x H x W torch segmap
    """
    rgb = imread(path)
    kitti_labels = encode_exact_segmap(rgb, kitti_classes, default_value=default)
    climategan_labels = merge_labels(kitti_labels, merge_map, default_value=default)
    # add batch and channel dimensions
    return torch.tensor(climategan_labels)[None, None]
149
+
150
+
151
def decode_segmap_merged_labels(tensor, domain, is_target, nc=11):
    """Creates a label colormap for classes used in Unity segmentation benchmark.
    Arguments:
        tensor -- segmented image of size (1) x (nc) x (H) x (W)
            if prediction, or size (1) x (1) x (H) x (W) if target
        domain -- key into the module-level `classes_dict` palette
        is_target -- whether `tensor` holds integer labels (True) or
            per-class scores (False)
    Returns:
        RGB tensor of size (1) x (3) x (H) x (W)
    #"""

    if is_target:
        # target: integer labels, drop the two leading singleton dims
        labels = tensor.squeeze(0).squeeze(0)
    else:
        # prediction: argmax over the class dimension
        labels = tensor.squeeze(0).argmax(dim=0)

    # keep RGB, drop a possible alpha column of the palette
    palette = torch.tensor(list(classes_dict[domain].values()))[:, :3]
    rgb = palette[labels.long()]  # H x W x 3
    return rgb.permute(2, 0, 1).to(torch.float32).unsqueeze(0)
167
+
168
+
169
def decode_segmap_cityscapes_labels(image, nc=19):
    """Creates a label colormap used in CITYSCAPES segmentation benchmark.
    Arguments:
        image {array} -- segmented image
        (array of image size containing class at each pixel)
        # assumes a 2-D (H, W) label array — TODO confirm with callers
    Returns:
        array of size 3*nc -- A colormap for visualizing segmentation results.
    """
    colormap = np.array(
        [
            [128, 64, 128],
            [244, 35, 232],
            [70, 70, 70],
            [102, 102, 156],
            [190, 153, 153],
            [153, 153, 153],
            [250, 170, 30],
            [220, 220, 0],
            [107, 142, 35],
            [152, 251, 152],
            [70, 130, 180],
            [220, 20, 60],
            [255, 0, 0],
            [0, 0, 142],
            [0, 0, 70],
            [0, 60, 100],
            [0, 80, 100],
            [0, 0, 230],
            [119, 11, 32],
        ],
        dtype=np.uint8,
    )

    # pixels whose label is outside [0, nc) stay black
    rgb = np.zeros((*image.shape, 3), dtype=np.uint8)
    for label in range(nc):
        rgb[image == label] = colormap[label]
    return rgb
210
+
211
+
212
def find_closest_class(pixel, dict_classes):
    """Takes a pixel as input and finds the closest known pixel value
    corresponding to a class in dict_classes (Euclidean distance; first
    match wins on ties; returns None for an empty dict).

    Arguments:
        pixel -- tuple pixel (R,G,B,A)
    Returns:
        tuple pixel (R,G,B,A) corresponding to a key (a class) in dict_classes
    """
    closest, closest_dist = None, float("inf")
    for candidate in dict_classes:
        dist = np.linalg.norm(np.subtract(pixel, candidate))
        if dist < closest_dist:
            closest_dist, closest = dist, candidate
    return closest
229
+
230
+
231
def encode_segmap(arr, domain):
    """Change a segmentation RGBA array to a segmentation array
    with each pixel being the index of the class; unknown colors are
    snapped to the nearest known class color.
    Arguments:
        numpy array -- segmented image of size (H) x (W) x (4 RGBA values)
    Returns:
        numpy array of size (1) x (H) x (W) with each pixel being the index of the class
    """
    height, width = arr.shape[0], arr.shape[1]
    labels = np.zeros((1, height, width))
    # invert the palette: rgba tuple -> class id
    rgba_to_class = {
        tuple(rgba_value): class_id
        for (class_id, rgba_value) in classes_dict[domain].items()
    }
    for row in range(height):
        for col in range(width):
            rgba = tuple(arr[row, col, :])
            if rgba not in rgba_to_class:
                rgba = find_closest_class(rgba, rgba_to_class)
            labels[0, row, col] = rgba_to_class[rgba]
    return labels
253
+
254
+
255
def encode_mask_label(arr, domain):
    """Change a segmentation RGBA array to a segmentation array
    with each pixel being the index of the class whose palette color is
    closest (Euclidean distance in RGB space).
    Arguments:
        numpy array -- segmented image of size (H) x (W) x (3 RGB values)
    Returns:
        numpy array of size (1) x (H) x (W) with each pixel being the index of the class
    """
    n_classes = len(classes_dict[domain].keys())
    height, width = arr.shape[0], arr.shape[1]
    # one distance map per class, then argmin over the class axis
    distances = np.zeros((n_classes, height, width))
    for class_id, rgb in classes_dict[domain].items():
        tiled = np.tile(rgb, (height, width, 1))
        distances[class_id, :, :] = np.sqrt(np.square(arr - tiled).sum(axis=2))
    return np.expand_dims(np.argmin(distances, axis=0), axis=0)
272
+
273
+
274
def transform_segmap_image_to_tensor(path, domain):
    """
    Transforms a segmentation image to a tensor of size (1) x (1) x (H) x (W)
    with each pixel being the index of the class
    """
    rgba = np.array(Image.open(path).convert("RGBA"))
    labels = encode_segmap(rgba, domain)
    # to float tensor, then add the batch dimension
    return torch.from_numpy(labels).float().unsqueeze(0)
284
+
285
+
286
def save_segmap_tensors(path_to_json, path_to_dir, domain):
    """
    Loads the segmentation images mentionned in a json file, transforms them to
    tensors and save the tensors in the wanted directory

    Args:
        path_to_json: complete path to the json file where to find the original data
        path_to_dir: path to the directory where to save the tensors as tensor_name.pt
        domain: domain of the images ("r" or "s")

    e.g:
    save_tensors(
        "/network/tmp1/ccai/data/climategan/seg/train_s.json",
        "/network/tmp1/ccai/data/munit_dataset/simdata/Unity11K_res640/Seg_tensors/",
        "s",
    )
    """
    ims_list = None
    if path_to_json:
        with open(Path(path_to_json).resolve(), "r") as f:
            ims_list = yaml.safe_load(f)

    assert ims_list is not None

    for im_dict in ims_list:
        for task_name, path in im_dict.items():
            if task_name != "s":
                continue
            # keep only the file name, without directory or extension
            stem = os.path.splitext(path)[0].rsplit("/", 1)[-1]
            tensor = transform_segmap_image_to_tensor(path, domain)
            torch.save(tensor, path_to_dir + stem + ".pt")
318
+
319
+
320
def pil_image_loader(path, task):
    """Load an image (or .npy array) from `path` as a PIL Image.

    Drops any alpha channel; for masks (task == "m") binarizes to {0, 1}
    and keeps a single channel.

    Args:
        path (str | pathlib.Path): file to load (.npy or image file)
        task (str): data type; only "m" changes behavior here

    Raises:
        ValueError: unknown file type

    Returns:
        PIL.Image.Image
    """
    if Path(path).suffix == ".npy":
        data = np.load(path).astype(np.uint8)
    elif is_image_file(path):
        data = np.array(Image.open(path).convert("RGB"))
    else:
        raise ValueError("Unknown data type {}".format(path))

    # Convert from RGBA to RGB for images
    if data.ndim == 3 and data.shape[-1] == 4:
        data = data[:, :, :3]

    if task == "m":
        data[data != 0] = 1
        # Make sure mask is single-channel
        if data.ndim >= 3:
            data = data[:, :, 0]

    return Image.fromarray(data)
342
+
343
+
344
def tensor_loader(path, task, domain, opts):
    """load data as tensors
    Args:
        path (str): path to data
        task (str): "s" (segmentation), "d" (depth), "x" (input image),
            "m" (mask); other values fall through the generic image branch
        domain (str): data domain; only "kitti" is special-cased here
        opts: options dict; only train.pseudo.tasks and
            gen.d.classify.enable are read (depth normalization flags)
    Returns:
        [Tensor]: 1 x C x H x W
    """
    if task == "s":
        # kitti segmaps are stored as RGB images and converted on the fly;
        # other domains have pre-computed .pt tensors on disk
        if domain == "kitti":
            return process_kitti_seg(
                path, classes_dict["kitti"], kitti_mapping, default=14
            )
        return torch.load(path)
    elif task == "d":
        if Path(path).suffix == ".npy":
            arr = np.load(path)
        else:
            arr = imread(path)  # .astype(np.uint8) /!\ kitti is np.uint16
        tensor = torch.from_numpy(arr.astype(np.float32))
        # domain-specific depth normalization (see tutils)
        tensor = get_normalized_depth_t(
            tensor,
            domain,
            normalize="d" in opts.train.pseudo.tasks,
            log=opts.gen.d.classify.enable,
        )
        tensor = tensor.unsqueeze(0)
        return tensor

    elif Path(path).suffix == ".npy":
        arr = np.load(path).astype(np.float32)
    elif is_image_file(path):
        arr = imread(path).astype(np.float32)
    else:
        raise ValueError("Unknown data type {}".format(path))

    # Convert from RGBA to RGB for images
    if len(arr.shape) == 3 and arr.shape[-1] == 4:
        arr = arr[:, :, 0:3]

    if task == "x":
        # min-max scale input images to [0, 1], then HWC -> CHW
        arr -= arr.min()
        arr /= arr.max()
        arr = np.moveaxis(arr, 2, 0)
    elif task == "s":
        arr = np.moveaxis(arr, 2, 0)
    elif task == "m":
        # binarize masks stored as 0/255 images
        if arr.max() > 127:
            arr = (arr > 127).astype(arr.dtype)
        # Make sure mask is single-channel
        if len(arr.shape) >= 3:
            arr = arr[:, :, 0]
        arr = np.expand_dims(arr, 0)

    # add the batch dimension: 1 x C x H x W
    return torch.from_numpy(arr).unsqueeze(0)
400
+
401
+
402
class OmniListDataset(Dataset):
    """
    Dataset reading a json/yaml file list of {task: path} samples for one
    (mode, domain) pair, loading each task's data as a tensor via
    `tensor_loader` and applying the provided transforms.
    """

    def __init__(self, mode, domain, opts, transform=None):
        # mode (str): "train" or "val"
        # domain (str): data domain key into opts.data.files[mode]
        # opts: full options tree
        # transform: callable applied to the {task: tensor} dict

        self.opts = opts
        self.domain = domain
        self.mode = mode
        # tasks to load: the model's tasks, plus always the input image "x",
        # and the mask "m" whenever the painter "p" is trained
        self.tasks = set(opts.tasks)
        self.tasks.add("x")
        if "p" in self.tasks:
            self.tasks.add("m")

        # file list may be an absolute path or relative to opts.data.files.base
        file_list_path = Path(opts.data.files[mode][domain])
        if "/" not in str(file_list_path):
            file_list_path = Path(opts.data.files.base) / Path(
                opts.data.files[mode][domain]
            )

        if file_list_path.suffix == ".json":
            self.samples_paths = self.json_load(file_list_path)
        elif file_list_path.suffix in {".yaml", ".yml"}:
            self.samples_paths = self.yaml_load(file_list_path)
        else:
            raise ValueError("Unknown file list type in {}".format(file_list_path))

        # optionally truncate the dataset for debugging / ablations
        if opts.data.max_samples and opts.data.max_samples != -1:
            assert isinstance(opts.data.max_samples, int)
            self.samples_paths = self.samples_paths[: opts.data.max_samples]

        self.filter_samples()
        if opts.data.check_samples:
            print(f"Checking samples ({mode}, {domain})")
            self.check_samples()
        self.file_list_path = str(file_list_path)
        self.transform = transform

    def filter_samples(self):
        """
        Filter out data which is not required for the model's tasks
        as defined in opts.tasks
        """
        self.samples_paths = [
            {k: v for k, v in s.items() if k in self.tasks} for s in self.samples_paths
        ]

    def __getitem__(self, i):
        """Return an item in the dataset with fields:
        {
            data: transform({
                domains: values
            }),
            paths: [{task: path}],
            domain: [domain],
            mode: [train|val]
        }
        Args:
            i (int): index of item to retrieve
        Returns:
            dict: dataset item where tensors of data are in item["data"] which is a dict
                  {task: tensor}
        """
        paths = self.samples_paths[i]

        # always apply transforms,
        # if no transform is specified, ToTensor and Normalize will be applied

        item = {
            "data": self.transform(
                {
                    task: tensor_loader(
                        env_to_path(path),  # expand $VARS in stored paths
                        task,
                        self.domain,
                        self.opts,
                    )
                    for task, path in paths.items()
                }
            ),
            "paths": paths,
            # kitti data is treated as simulated ("s") downstream
            "domain": self.domain if self.domain != "kitti" else "s",
            "mode": self.mode,
        }

        return item

    def __len__(self):
        return len(self.samples_paths)

    def json_load(self, file_path):
        with open(file_path, "r") as f:
            return json.load(f)

    def yaml_load(self, file_path):
        with open(file_path, "r") as f:
            return yaml.safe_load(f)

    def check_samples(self):
        """Checks that every file listed in samples_paths actually
        exist on the file-system
        """
        for s in self.samples_paths:
            for k, v in s.items():
                assert Path(v).exists(), f"{k} {v} does not exist"
504
+
505
+
506
def get_loader(mode, domain, opts):
    """Build the DataLoader for one (mode, domain) pair.

    Uses the kitti-specific batch size only when kitti pre-training is
    enabled and a batch size is configured for it; otherwise falls back to
    the generic loader batch size.

    Args:
        mode (str): "train" or "val"
        domain (str): data domain
        opts: full options tree

    Returns:
        torch.utils.data.DataLoader
    """
    use_kitti_batch_size = (
        domain == "kitti"
        and opts.train.kitti.pretrain
        and opts.train.kitti.batch_size
    )
    if use_kitti_batch_size:
        batch_size = opts.train.kitti.get("batch_size", 4)
    else:
        batch_size = opts.data.loaders.get("batch_size", 4)

    dataset = OmniListDataset(
        mode,
        domain,
        opts,
        transform=transforms.Compose(get_transforms(opts, mode, domain)),
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=opts.data.loaders.get("num_workers", 8),
        pin_memory=True,  # faster transfer to gpu
        drop_last=True,  # avoids batchnorm pbs if last batch has size 1
    )
529
+
530
+
531
def get_all_loaders(opts):
    """Build one DataLoader per (mode, domain) combination that has a file
    list configured in opts.data.files.

    Returns:
        dict: {"train": {domain: loader}, "val": {domain: loader}}
    """
    loaders = {"train": {}, "val": {}}
    for mode, mode_loaders in loaders.items():
        for domain in opts.domains:
            if mode in opts.data.files and domain in opts.data.files[mode]:
                mode_loaders[domain] = get_loader(mode, domain, opts)
    return loaders
climategan/deeplab/__init__.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
6
+ from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
7
+ from climategan.deeplab.mobilenet_v3 import MobileNetV2
8
+ from climategan.deeplab.resnet101_v3 import ResNet101
9
+ from climategan.deeplab.resnetmulti_v2 import ResNetMulti
10
+
11
+
12
def create_encoder(opts, no_init=False, verbose=0):
    """Instantiate the generator's encoder from opts.gen.encoder.architecture.

    Args:
        opts (addict.Dict): global options
        no_init (bool): skip weight init / pre-trained loading
        verbose (int): print progress messages if > 0

    Returns:
        nn.Module: DeeplabV2Encoder or a DeepLabv3 backbone

    Raises:
        NotImplementedError: unknown architecture
    """
    architecture = opts.gen.encoder.architecture
    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add Deeplabv2 Encoder")
        return DeeplabV2Encoder(opts, no_init, verbose)
    if architecture == "deeplabv3":
        if verbose > 0:
            # fix: local variable was misspelled "backone"
            backbone = opts.gen.deeplabv3.backbone
            print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
        return build_v3_backbone(opts, no_init)
    raise NotImplementedError(
        "Unknown encoder: {}".format(opts.gen.encoder.architecture)
    )
26
+
27
+
28
def create_segmentation_decoder(opts, no_init=False, verbose=0):
    """Instantiate the segmentation decoder matching opts.gen.s.architecture.

    Args:
        opts (addict.Dict): global options
        no_init (bool): forwarded to DeepLabV3Decoder
        verbose (int): print progress messages if > 0

    Returns:
        nn.Module: DeepLabV2Decoder or DeepLabV3Decoder

    Raises:
        NotImplementedError: unknown architecture
    """
    architecture = opts.gen.s.architecture
    if architecture == "deeplabv2":
        if verbose > 0:
            print(" - Add DeepLabV2Decoder")
        return DeepLabV2Decoder(opts)
    if architecture == "deeplabv3":
        if verbose > 0:
            print(" - Add DeepLabV3Decoder")
        return DeepLabV3Decoder(opts, no_init)
    raise NotImplementedError(
        "Unknown Segmentation architecture: {}".format(opts.gen.s.architecture)
    )
41
+
42
+
43
def build_v3_backbone(opts, no_init, verbose=0):
    """Create the DeepLabv3+ backbone: ResNet101 or MobileNetV2.

    For "resnet", pre-trained weights (keys prefixed "backbone.") are
    loaded here unless `no_init` is set. For "mobilenet", the checkpoint
    path is handed to MobileNetV2 which loads it itself.

    Args:
        opts (addict.Dict): global options (reads opts.gen.deeplabv3)
        no_init (bool): skip pre-trained weight loading (resnet only)
        verbose (int): forwarded to ResNet101

    Returns:
        nn.Module: the backbone

    Raises:
        NotImplementedError: unknown backbone name
    """
    backbone = opts.gen.deeplabv3.backbone
    output_stride = opts.gen.deeplabv3.output_stride

    if backbone == "resnet":
        resnet = ResNet101(
            output_stride=output_stride,
            BatchNorm=nn.BatchNorm2d,
            verbose=verbose,
            no_init=no_init,
        )
        if not no_init:
            # fix: removed redundant re-check of `backbone == "resnet"`
            # (always true inside this branch)
            assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()

            std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
            resnet.load_state_dict(
                {
                    k.replace("backbone.", ""): v
                    for k, v in std.items()
                    if k.startswith("backbone.")
                }
            )
            print(
                " - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
            )
        return resnet

    if backbone == "mobilenet":
        assert Path(opts.gen.deeplabv3.pretrained_model.mobilenet).exists()
        mobilenet = MobileNetV2(
            no_init=no_init,
            pretrained_path=opts.gen.deeplabv3.pretrained_model.mobilenet,
        )
        print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
        return mobilenet

    raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
81
+
82
+
83
class DeeplabV2Encoder(nn.Module):
    """Deeplab v2 encoder wrapping a ResNetMulti backbone, optionally
    initialized from a pre-trained checkpoint (layer5/resblock weights
    are skipped since they are not part of this backbone)."""

    def __init__(self, opts, no_init=False, verbose=0):
        """Deeplab architecture encoder"""
        super().__init__()

        self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
        if opts.gen.deeplabv2.use_pretrained and not no_init:
            saved_state_dict = torch.load(opts.gen.deeplabv2.pretrained_model)
            # start from the current weights, overwrite matching entries;
            # checkpoint keys drop their first dotted component
            new_params = self.model.state_dict().copy()
            for key in saved_state_dict:
                parts = key.split(".")
                if parts[1] not in ("layer5", "resblock"):
                    new_params[".".join(parts[1:])] = saved_state_dict[key]
            self.model.load_state_dict(new_params)
            if verbose > 0:
                print(" - Loaded pretrained weights")

    def forward(self, x):
        return self.model(x)
climategan/deeplab/deeplab_v2.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from climategan.blocks import InterpolateNearest2d
5
+ from climategan.utils import find_target_size
6
+
7
+
8
class _ASPPModule(nn.Module):
    # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
    """One ASPP branch: atrous convolution -> batch-norm -> ReLU."""

    def __init__(
        self, inplanes, planes, kernel_size, padding, dilation, BatchNorm, no_init
    ):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()
        if not no_init:
            self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        # kaiming init for convs; identity-like affine for batch-norms
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
+
42
+
43
class ASPP(nn.Module):
    # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
    """Atrous Spatial Pyramid Pooling: four parallel atrous branches plus
    a global-average-pooling branch, concatenated then fused to 256
    channels with dropout."""

    def __init__(self, backbone, output_stride, BatchNorm, no_init):
        super().__init__()

        inplanes = 320 if backbone == "mobilenet" else 2048

        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(
            inplanes,
            256,
            1,
            padding=0,
            dilation=dilations[0],
            BatchNorm=BatchNorm,
            no_init=no_init,
        )
        self.aspp2 = _ASPPModule(
            inplanes,
            256,
            3,
            padding=dilations[1],
            dilation=dilations[1],
            BatchNorm=BatchNorm,
            no_init=no_init,
        )
        self.aspp3 = _ASPPModule(
            inplanes,
            256,
            3,
            padding=dilations[2],
            dilation=dilations[2],
            BatchNorm=BatchNorm,
            no_init=no_init,
        )
        self.aspp4 = _ASPPModule(
            inplanes,
            256,
            3,
            padding=dilations[3],
            dilation=dilations[3],
            BatchNorm=BatchNorm,
            no_init=no_init,
        )

        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)  # 5 * 256 -> 256
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        if not no_init:
            self._init_weight()

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # broadcast the pooled vector back to the branches' spatial size
        pooled = F.interpolate(
            pooled, size=branches[-1].size()[2:], mode="bilinear", align_corners=True
        )
        fused = torch.cat(branches + [pooled], dim=1)

        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
+
135
+
136
class DeepLabV2Decoder(nn.Module):
    # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/decoder.py
    # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
    """DeepLabv2 segmentation decoder: ASPP over 2048-channel encoder
    features, a small conv head, then bilinear upsampling of the logits
    to `self._target_size`.
    """

    def __init__(self, opts, no_init=False):
        super().__init__()
        self.aspp = ASPP("resnet", 16, nn.BatchNorm2d, no_init)
        # DADA-style depth-aware fusion: multiply features by the depth map
        self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada

        conv_modules = [
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        if opts.gen.s.upsample_featuremaps:
            conv_modules = [InterpolateNearest2d(scale_factor=2)] + conv_modules

        conv_modules += [
            nn.Conv2d(256, opts.gen.s.output_dim, kernel_size=1, stride=1),
        ]
        self.conv = nn.Sequential(*conv_modules)

        self._target_size = find_target_size(opts, "s")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, z_depth=None):
        """Decode encoder features `z` into segmentation logits.

        Args:
            z (Tensor or list/tuple): 2048-channel encoder features; if a
                sequence, only the first element is used
            z_depth (Tensor, optional): depth features multiplied into `z`
                when DADA is enabled

        Returns:
            Tensor: logits interpolated to `self._target_size`

        Raises:
            ValueError: target size unset, or wrong channel count
        """
        if self._target_size is None:
            # fix: the two message halves previously concatenated without
            # a separating space
            error = "self._target_size should be set with self.set_target_size() "
            error += "to interpolate logits to the target seg map's size"
            # fix: raise ValueError (was bare Exception) for consistency
            # with DeepLabV3Decoder; ValueError is still caught by any
            # caller catching Exception
            raise ValueError(error)
        if isinstance(z, (list, tuple)):
            z = z[0]
        if z.shape[1] != 2048:
            raise ValueError(
                "Segmentation decoder will only work with 2048 channels for z"
            )

        if z_depth is not None and self.use_dada:
            z = z * z_depth

        y = self.aspp(z)
        y = self.conv(y)
        return F.interpolate(y, self._target_size, mode="bilinear", align_corners=True)
climategan/deeplab/deeplab_v3.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py
3
+ """
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from climategan.deeplab.mobilenet_v3 import SeparableConv2d
10
+ from climategan.utils import find_target_size
11
+
12
+
13
class _DeepLabHead(nn.Module):
    """Segmentation head for the MobileNet backbone: two separable 3x3
    convolutions followed by a 1x1 classifier."""

    def __init__(
        self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
    ):
        super().__init__()
        last_channels = c4_channels
        self.block = nn.Sequential(
            SeparableConv2d(
                last_channels, 256, 3, norm_layer=norm_layer, relu_first=False
            ),
            SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False),
            nn.Conv2d(256, nclass, 1),
        )

    def forward(self, x, c1=None):
        # `c1` (low-level features) is accepted for interface parity
        # with the resnet decoder but is not used here
        return self.block(x)
31
+
32
+
33
class ConvBNReLU(nn.Module):
    """
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    NOTE: despite the name, `forward` applies conv + batch-norm only —
    no ReLU — matching the upstream implementation.
    """

    def __init__(
        self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_chan,
            out_chan,
            kernel_size=ks,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        return self.bn(self.conv(x))

    def init_weight(self):
        for child in self.children():
            if isinstance(child, nn.Conv2d):
                nn.init.kaiming_normal_(child.weight, a=1)
                if child.bias is not None:
                    nn.init.constant_(child.bias, 0)
65
+
66
+
67
class ASPPv3Plus(nn.Module):
    """
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Four parallel atrous branches (rates 1/6/12/18), optionally a
    global-pooling branch (`with_gp`), concatenated then fused by a
    1x1 ConvBNReLU.
    """

    def __init__(self, backbone, no_init):
        super().__init__()

        in_chan = 320 if backbone == "mobilenet" else 2048

        self.with_gp = False
        self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
        self.conv2 = ConvBNReLU(in_chan, 256, ks=3, dilation=6, padding=6)
        self.conv3 = ConvBNReLU(in_chan, 256, ks=3, dilation=12, padding=12)
        self.conv4 = ConvBNReLU(in_chan, 256, ks=3, dilation=18, padding=18)
        if self.with_gp:
            self.avg = nn.AdaptiveAvgPool2d((1, 1))
            self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
            self.conv_out = ConvBNReLU(256 * 5, 256, ks=1)
        else:
            self.conv_out = ConvBNReLU(256 * 4, 256, ks=1)

        if not no_init:
            self.init_weight()

    def forward(self, x):
        height, width = x.size()[2:]
        feats = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        if self.with_gp:
            pooled = self.conv1x1(self.avg(x))
            feats.append(
                F.interpolate(
                    pooled, (height, width), mode="bilinear", align_corners=True
                )
            )
        return self.conv_out(torch.cat(feats, 1))

    def init_weight(self):
        for child in self.children():
            if isinstance(child, nn.Conv2d):
                nn.init.kaiming_normal_(child.weight, a=1)
                if child.bias is not None:
                    nn.init.constant_(child.bias, 0)
117
+
118
+
119
class Decoder(nn.Module):
    """
    https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py

    Fuses low-level backbone features (reduced to 48 channels) with
    upsampled ASPP features, then predicts per-class logits.
    """

    def __init__(self, n_classes):
        super(Decoder, self).__init__()
        self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
        self.conv_cat = nn.Sequential(
            ConvBNReLU(304, 256, ks=3, padding=1),
            ConvBNReLU(256, 256, ks=3, padding=1),
        )
        self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)

    def forward(self, feat_low, feat_aspp):
        height, width = feat_low.size()[2:]
        low = self.conv_low(feat_low)
        aspp_up = F.interpolate(
            feat_aspp, (height, width), mode="bilinear", align_corners=True
        )
        fused = self.conv_cat(torch.cat([low, aspp_up], dim=1))
        return self.conv_out(fused)
143
+
144
+
145
+ """
146
+ https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
147
+ """
148
+
149
+
150
class DeepLabV3Decoder(nn.Module):
    """DeepLabv3+ segmentation decoder.

    With a "resnet" backbone: ASPPv3Plus + low-level-feature Decoder.
    With a "mobilenet" backbone: a single _DeepLabHead.
    Logits are bilinearly upsampled to `self._target_size`.
    """

    def __init__(
        self,
        opts,
        no_init=False,
        freeze_bn=False,
    ):
        super().__init__()

        num_classes = opts.gen.s.output_dim
        self.backbone = opts.gen.deeplabv3.backbone
        # DADA-style depth-aware fusion: multiply features by the depth map
        self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada

        if self.backbone == "resnet":
            self.aspp = ASPPv3Plus(self.backbone, no_init)
            self.decoder = Decoder(num_classes)

            # NOTE(review): this boolean attribute shadows the `freeze_bn`
            # method below on resnet-backbone instances, so
            # `instance.freeze_bn()` would fail there — confirm intended
            # usage before renaming either (renaming would change the
            # public interface).
            self.freeze_bn = freeze_bn
        else:
            self.head = _DeepLabHead(num_classes, c4_channels=320)

        self._target_size = find_target_size(opts, "s")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        if not no_init:
            # default init (kaiming convs, unit BN, small-normal linears)
            # before overwriting with pre-trained weights
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.zeros_(m.bias)

            self.load_pretrained(opts)

    def load_pretrained(self, opts):
        """Load pre-trained decoder weights, skipping the 19-class
        Cityscapes classifier tensors (shape[0] == 19) so the head can be
        re-sized to this task's number of classes.

        Raises:
            AssertionError: unknown backbone, or checkpoint file missing
        """
        assert opts.gen.deeplabv3.backbone in {"resnet", "mobilenet"}
        if opts.gen.deeplabv3.backbone == "resnet":
            # fix: only assert the *resnet* checkpoint exists in this
            # branch (previously asserted unconditionally, breaking
            # mobilenet configs without a resnet checkpoint on disk)
            assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()
            std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
            self.aspp.load_state_dict(
                {
                    k.replace("aspp.", ""): v
                    for k, v in std.items()
                    if k.startswith("aspp.")
                }
            )
            self.decoder.load_state_dict(
                {
                    k.replace("decoder.", ""): v
                    for k, v in std.items()
                    if k.startswith("decoder.")
                    and not (len(v.shape) > 0 and v.shape[0] == 19)
                },
                strict=False,
            )
            print(
                "- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
            )
        else:
            assert Path(opts.gen.deeplabv3.pretrained_model.mobilenet).exists()
            std = torch.load(opts.gen.deeplabv3.pretrained_model.mobilenet)
            self.load_state_dict(
                {
                    k: v
                    for k, v in std.items()
                    if k.startswith("head.")
                    and not (len(v.shape) > 0 and v.shape[0] == 19)
                },
                strict=False,
            )
            print(
                " - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
            )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, z_depth=None):
        """Decode backbone features into segmentation logits.

        Args:
            z (tuple, list): (high-level features, low-level features)
            z_depth (Tensor, optional): depth features multiplied into the
                high-level features when DADA is enabled

        Returns:
            Tensor: logits interpolated to `self._target_size`

        Raises:
            ValueError: if the target size was never set
        """
        assert isinstance(z, (tuple, list))
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size()"
            error += "to interpolate logits to the target seg map's size"
            raise ValueError(error)

        z_high, z_low = z

        if z_depth is not None and self.use_dada:
            z_high = z_high * z_depth

        if self.backbone == "resnet":
            z_high = self.aspp(z_high)
            s = self.decoder(z_high, z_low)
        else:
            s = self.head(z_high)

        s = F.interpolate(
            s, size=self._target_size, mode="bilinear", align_corners=True
        )

        return s

    def freeze_bn(self):
        """Put every BatchNorm2d in eval mode (frozen running stats)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
climategan/deeplab/mobilenet_v3.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ from https://github.com/LikeLy-Journey/SegmenTron/blob/
3
+ 4bc605eedde7d680314f63d329277b73f83b1c5f/segmentron/modules/basic.py#L34
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from climategan.blocks import InterpolateNearest2d
12
+
13
+
14
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution (depthwise k x k + pointwise 1x1),
    each followed by normalization; `relu_first` selects between a
    pre-activation layout and interleaved ReLUs."""

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=3,
        stride=1,
        dilation=1,
        relu_first=True,
        bias=False,
        norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
        depthwise = nn.Conv2d(
            inplanes,
            inplanes,
            kernel_size,
            stride=stride,
            padding=dilation,  # padding == dilation keeps spatial size for k=3
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        bn_depth = norm_layer(inplanes)
        pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
        bn_point = norm_layer(planes)

        # submodule names are part of checkpoint state-dict keys: keep them
        if relu_first:
            ordered = [
                ("relu", nn.ReLU()),
                ("depthwise", depthwise),
                ("bn_depth", bn_depth),
                ("pointwise", pointwise),
                ("bn_point", bn_point),
            ]
        else:
            ordered = [
                ("depthwise", depthwise),
                ("bn_depth", bn_depth),
                ("relu1", nn.ReLU(inplace=True)),
                ("pointwise", pointwise),
                ("bn_point", bn_point),
                ("relu2", nn.ReLU(inplace=True)),
            ]
        self.block = nn.Sequential(OrderedDict(ordered))

    def forward(self, x):
        return self.block(x)
69
+
70
+
71
class _ConvBNReLU(nn.Module):
    """Conv2d -> norm -> ReLU (or ReLU6, the MobileNet activation)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu6=False,
        norm_layer=nn.BatchNorm2d,
    ):
        super(_ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)
        self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
103
+
104
+
105
class _DepthwiseConv(nn.Module):
    """conv_dw in MobileNet: depthwise 3x3 followed by pointwise 1x1."""

    def __init__(
        self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs
    ):
        super(_DepthwiseConv, self).__init__()
        depthwise = _ConvBNReLU(
            in_channels,
            in_channels,
            3,
            stride,
            1,
            groups=in_channels,
            norm_layer=norm_layer,
        )
        pointwise = _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)
        self.conv = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        return self.conv(x)
127
+
128
+
129
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual: 1x1 expansion -> 3x3 depthwise ->
    1x1 linear projection, with a skip connection when stride is 1 and
    input/output channel counts match."""

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        expand_ratio,
        dilation=1,
        norm_layer=nn.BatchNorm2d,
    ):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.use_res_connect = stride == 1 and in_channels == out_channels

        hidden = int(round(in_channels * expand_ratio))
        layers = []
        if expand_ratio != 1:
            # pw: pointwise expansion
            layers.append(
                _ConvBNReLU(
                    in_channels, hidden, 1, relu6=True, norm_layer=norm_layer
                )
            )
        # dw: depthwise (possibly atrous) conv
        layers.append(
            _ConvBNReLU(
                hidden,
                hidden,
                3,
                stride,
                dilation,
                dilation,
                groups=hidden,
                relu6=True,
                norm_layer=norm_layer,
            )
        )
        # pw-linear: projection without activation
        layers.append(nn.Conv2d(hidden, out_channels, 1, bias=False))
        layers.append(norm_layer(out_channels))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
178
+
179
+
180
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor for DeepLabv3+ (output stride 16).

    `forward` returns (c4, c1): the nearest-upsampled output of the last
    stage, and the low-level features from block2 (for the decoder).
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, pretrained_path=None, no_init=False):
        super(MobileNetV2, self).__init__()
        output_stride = 16  # hard-coded; selects dilations [1, 2] below
        self.multiplier = 1.0  # width multiplier (1.0 => reference channel counts)
        if output_stride == 32:
            dilations = [1, 1]
        elif output_stride == 16:
            dilations = [1, 2]
        elif output_stride == 8:
            dilations = [2, 4]
        else:
            raise NotImplementedError
        inverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        # building first layer
        input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
        # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
        self.conv1 = _ConvBNReLU(
            3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer
        )

        # building inverted residual blocks
        # self.planes is a running channel count, updated by _make_layer
        self.planes = input_channels
        self.block1 = self._make_layer(
            InvertedResidual,
            self.planes,
            inverted_residual_setting[0:1],
            norm_layer=norm_layer,
        )
        self.block2 = self._make_layer(
            InvertedResidual,
            self.planes,
            inverted_residual_setting[1:2],
            norm_layer=norm_layer,
        )
        self.block3 = self._make_layer(
            InvertedResidual,
            self.planes,
            inverted_residual_setting[2:3],
            norm_layer=norm_layer,
        )
        self.block4 = self._make_layer(
            InvertedResidual,
            self.planes,
            inverted_residual_setting[3:5],
            dilations[0],
            norm_layer=norm_layer,
        )
        self.block5 = self._make_layer(
            InvertedResidual,
            self.planes,
            inverted_residual_setting[5:],
            dilations[1],
            norm_layer=norm_layer,
        )
        self.last_inp_channels = self.planes

        # x2 nearest-neighbor upsampling applied to block5's output
        self.up2 = InterpolateNearest2d()

        # weight initialization
        if not no_init:
            self.pretrained_path = pretrained_path
            if pretrained_path is not None:
                self._load_pretrained_model()
            else:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, mode="fan_out")
                        if m.bias is not None:
                            nn.init.zeros_(m.bias)
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.ones_(m.weight)
                        nn.init.zeros_(m.bias)
                    elif isinstance(m, nn.Linear):
                        nn.init.normal_(m.weight, 0, 0.01)
                        if m.bias is not None:
                            nn.init.zeros_(m.bias)

    def _make_layer(
        self,
        block,
        planes,
        inverted_residual_setting,
        dilation=1,
        norm_layer=nn.BatchNorm2d,
    ):
        # Build one stage of InvertedResidual blocks and record the stage's
        # output channel count in self.planes.
        features = list()
        for t, c, n, s in inverted_residual_setting:
            out_channels = int(c * self.multiplier)
            stride = s if dilation == 1 else 1  # dilated stages keep stride 1
            features.append(
                block(planes, out_channels, stride, t, dilation, norm_layer)
            )
            planes = out_channels
            for i in range(n - 1):
                # NOTE(review): the repeated blocks omit `dilation`
                # (defaults to 1) — confirm this matches the intended
                # atrous behavior for the dilated stages
                features.append(
                    block(planes, out_channels, 1, t, norm_layer=norm_layer)
                )
                planes = out_channels
        self.planes = planes
        return nn.Sequential(*features)

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        c1 = self.block2(x)  # low-level features for the decoder
        c2 = self.block3(c1)
        c3 = self.block4(c2)
        c4 = self.up2(self.block5(c3))  # high-level features, upsampled x2

        # x = self.features(x)
        # x = self.classifier(x.view(x.size(0), x.size(1)))
        return c4, c1

    def _load_pretrained_model(self):
        # Load checkpoint keys that match this model's state dict; keys may
        # carry an "encoder." prefix, which is stripped. Non-matching keys
        # are counted and reported, not loaded.
        assert self.pretrained_path is not None
        assert Path(self.pretrained_path).exists()

        pretrain_dict = torch.load(self.pretrained_path)
        pretrain_dict = {k.replace("encoder.", ""): v for k, v in pretrain_dict.items()}
        model_dict = {}
        state_dict = self.state_dict()
        ignored = []
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
            else:
                ignored.append(k)
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
        self.loaded_pre_trained = True
        print(
            " - Loaded pre-trained MobileNetV2: ignored {}/{} keys".format(
                len(ignored), len(pretrain_dict)
            )
        )
climategan/deeplab/resnet101_v3.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class Bottleneck(nn.Module):
    """Standard ResNet bottleneck: 1x1 reduce -> 3x3 (atrous) -> 1x1
    expand (x4), with an optional downsample on the shortcut."""

    expansion = 4

    def __init__(
        self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None
    ):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,  # padding == dilation preserves spatial size
            bias=False,
        )
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
51
+
52
+
53
class ResNet(nn.Module):
    """Atrous ResNet backbone for DeepLabv3+.

    `output_stride` selects the stride/dilation schedule: 16 keeps stride
    2 in layer3 and dilates layer4; 8 keeps layers 3/4 at stride 1 with
    dilations 2/4. `forward` returns (layer4 output, layer1 output) —
    the second is the low-level feature map used by the decoder.
    """

    def __init__(
        self, block, layers, output_stride, BatchNorm, verbose=0, no_init=False
    ):
        # NOTE(review): `no_init` is accepted but not used in this class —
        # pre-trained weights appear to be loaded by the caller
        # (build_v3_backbone); confirm it can stay unused.
        self.inplanes = 64  # running channel count consumed by _make_layer
        self.verbose = verbose
        super(ResNet, self).__init__()
        blocks = [1, 2, 4]  # multi-grid dilation multipliers for layer4
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=strides[0],
            dilation=dilations[0],
            BatchNorm=BatchNorm,
        )
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=strides[1],
            dilation=dilations[1],
            BatchNorm=BatchNorm,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=strides[2],
            dilation=dilations[2],
            BatchNorm=BatchNorm,
        )
        self.layer4 = self._make_MG_unit(
            block,
            512,
            blocks=blocks,
            stride=strides[3],
            dilation=dilations[3],
            BatchNorm=BatchNorm,
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        # Standard ResNet stage: a 1x1-conv shortcut projection when the
        # shape changes, followed by `blocks` bottleneck blocks.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)
            )

        return nn.Sequential(*layers)

    def _make_MG_unit(
        self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None
    ):
        # Multi-grid stage: the i-th block's dilation is blocks[i] * dilation.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                dilation=blocks[0] * dilation,
                downsample=downsample,
                BatchNorm=BatchNorm,
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=blocks[i] * dilation,
                    BatchNorm=BatchNorm,
                )
            )

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x  # kept for the DeepLabv3+ decoder
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat
188
+
189
+
190
def ResNet101(output_stride=8, BatchNorm=nn.BatchNorm2d, verbose=0, no_init=False):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # [3, 4, 23, 3] bottleneck blocks per stage defines ResNet-101
    return ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        output_stride,
        BatchNorm,
        verbose=verbose,
        no_init=no_init,
    )
climategan/deeplab/resnetmulti_v2.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from climategan.blocks import ResBlocks
3
+
4
+ affine_par = True
5
+
6
+
7
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> dilated 3x3 -> 1x1) with frozen BatchNorm.

    Every BatchNorm parameter has ``requires_grad`` set to False.
    Note: the stride is applied on the first 1x1 conv (a deliberate change
    from the torchvision layout, which strides the 3x3 conv).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        # change: stride on the 1x1 reduction conv
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=stride, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        # change: padding == dilation keeps the spatial size constant
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=False,
            dilation=dilation,
        )
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        # 1x1 expansion back to planes * expansion channels
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Freeze all BatchNorm parameters (backbone fine-tuning setup).
        for bn in (self.bn1, self.bn2, self.bn3):
            for param in bn.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Residual forward pass: main path + (optionally projected) shortcut."""
        identity = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        h += identity
        return self.relu(h)
57
+
58
+
59
class ResNetMulti(nn.Module):
    """Dilated ResNet backbone (DeepLab-v2 style) topped with extra ResBlocks.

    ``layers`` gives the number of Bottleneck blocks per stage. Stages 3
    and 4 keep stride 1 but use dilations 2 and 4 (output stride 8). All
    BatchNorm parameters are frozen. The final 2048-channel feature map is
    refined by ``n_res`` ResBlocks.
    """

    def __init__(
        self,
        layers,
        n_res=4,
        res_norm="instance",
        activ="lrelu",
        pad_type="reflect",
    ):
        # Running channel count, mutated by _make_layer as stages are built.
        self.inplanes = 64
        block = Bottleneck
        super(ResNetMulti, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        # Freeze the stem's BatchNorm parameters.
        for i in self.bn1.parameters():
            i.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=0, ceil_mode=True
        )  # changed padding from 1 to 0
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # Stride 1 + dilation: grows receptive field without losing resolution.
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        # Re-initialize conv and BatchNorm weights.
        # NOTE(review): this loop runs before layer_res is created below, so
        # the ResBlocks keep their own default initialization — presumably
        # intentional; confirm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        self.layer_res = ResBlocks(
            n_res, 2048, norm=res_norm, activation=activ, pad_type=pad_type
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` Bottleneck blocks.

        A projection shortcut (1x1 conv + frozen BatchNorm) is added when
        the stride, channel count, or dilation (2 or 4) requires it.
        """
        downsample = None
        if (
            stride != 1
            or self.inplanes != planes * block.expansion
            or dilation == 2
            or dilation == 4
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par),
            )
            # Freeze the shortcut's BatchNorm parameters too.
            for i in downsample._modules["1"].parameters():
                i.requires_grad = False
        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, dilation=dilation, downsample=downsample
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the ResBlocks-refined 2048-channel feature map (stride 8)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer_res(x)
        return x
climategan/depth.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
6
+ from climategan.utils import find_target_size
7
+
8
+
9
def create_depth_decoder(opts, no_init=False, verbose=0):
    """Create a depth decoder from the options.

    Returns a BaseDepthDecoder when ``opts.gen.d.architecture == "base"``
    (in which case DADA feature fusion must be disabled for the s and m
    tasks), otherwise a DADADepthDecoder.

    Args:
        opts (addict.Dict): global options
        no_init (bool): unused here; kept for API consistency with the
            other create_* factories
        verbose (int): print the created class name when > 0

    Returns:
        nn.Module: the depth decoder
    """
    if opts.gen.d.architecture == "base":
        decoder = BaseDepthDecoder(opts)
        # fix: was `opts.task`, which does not exist — the options object
        # exposes `opts.tasks` (as used on the next condition and
        # throughout the codebase)
        if "s" in opts.tasks:
            assert opts.gen.s.use_dada is False
        if "m" in opts.tasks:
            assert opts.gen.m.use_dada is False
    else:
        decoder = DADADepthDecoder(opts)

    if verbose > 0:
        print(f" - Add {decoder.__class__.__name__}")

    return decoder
23
+
24
+
25
class DADADepthDecoder(nn.Module):
    """
    Depth decoder based on the depth auxiliary task in the DADA paper.

    Encodes the shared representation down to 128 channels, predicts depth
    as the channel-wise mean of those features, and (optionally) projects
    them back to the encoder dimension for DADA feature fusion.
    """

    def __init__(self, opts):
        super().__init__()
        # Encoder output width: 320 for a MobileNet DeepLabv3 backbone,
        # 2048 otherwise.
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            res_dim = 320
        else:
            res_dim = 2048

        mid_dim = 512

        # When DADA fusion is enabled (for the masker or the segmenter),
        # dec4 projects the 128-dim depth features back to res_dim so they
        # can be fused with the shared representation.
        self.do_feat_fusion = False
        if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
            self.do_feat_fusion = True
            self.dec4 = Conv2dBlock(
                128,
                res_dim,
                1,
                stride=1,
                padding=0,
                bias=True,
                activation="lrelu",
                norm="none",
            )

        self.relu = nn.ReLU(inplace=True)
        # Bottleneck: res_dim -> 512 (1x1) -> 512 (3x3) -> 128 (1x1)
        self.enc4_1 = Conv2dBlock(
            res_dim,
            mid_dim,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_2 = Conv2dBlock(
            mid_dim,
            mid_dim,
            3,
            stride=1,
            padding=1,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_3 = Conv2dBlock(
            mid_dim,
            128,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        # Optional 2x upsampling head applied before the channel-mean.
        self.upsample = None
        if opts.gen.d.upsample_featuremaps:
            self.upsample = nn.Sequential(
                *[
                    InterpolateNearest2d(),
                    Conv2dBlock(
                        128,
                        32,
                        3,
                        stride=1,
                        padding=1,
                        bias=False,
                        activation="lrelu",
                        pad_type="reflect",
                        norm="batch",
                    ),
                    nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
                ]
            )
        self._target_size = find_target_size(opts, "d")
        print(
            "      - {}:  setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z):
        """Predict depth (and optional fusion features) from encoder output.

        Args:
            z (torch.Tensor | list | tuple): encoder features; if a
                list/tuple, only the first element is used

        Returns:
            tuple: (depth, z_depth) — depth map resized to the target size,
            and the res_dim fusion features (None when fusion is disabled)
        """
        if isinstance(z, (list, tuple)):
            z = z[0]
        z4_enc = self.enc4_1(z)
        z4_enc = self.enc4_2(z4_enc)
        z4_enc = self.enc4_3(z4_enc)

        z_depth = None
        if self.do_feat_fusion:
            z_depth = self.dec4(z4_enc)

        if self.upsample is not None:
            z4_enc = self.upsample(z4_enc)

        depth = torch.mean(z4_enc, dim=1, keepdim=True)  # DADA paper decoder
        # NOTE(review): the comparison and the (h, w) construction below
        # assume self._target_size is an int, but set_target_size() may
        # store a tuple — confirm callers only rely on the int form here.
        if depth.shape[-1] != self._target_size:
            depth = F.interpolate(
                depth,
                size=(384, 384),  # size used in MiDaS inference
                mode="bicubic",  # what MiDaS uses
                align_corners=False,
            )

        depth = F.interpolate(
            depth, (self._target_size, self._target_size), mode="nearest"
        )  # what we used in the transforms to resize input

        return depth, z_depth

    def __str__(self):
        return "DADA Depth Decoder"
159
+
160
+
161
class BaseDepthDecoder(BaseDecoder):
    """Depth decoder built on the generic BaseDecoder architecture.

    Configures input/output dimensions from the encoder architecture and
    the depth options, then interpolates predictions to the target size.
    """

    def __init__(self, opts):
        # -1 disables low-level feature fusion in BaseDecoder.
        low_level_feats_dim = -1
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.d.use_low_level_feats

        # Encoder-dependent feature widths.
        if use_v3 and use_mobile_net:
            input_dim = 320
            if use_low:
                low_level_feats_dim = 24
        elif use_v3:
            input_dim = 2048
            if use_low:
                low_level_feats_dim = 256
        else:
            input_dim = 2048

        n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
        # One channel for regression, or one channel per bucket when depth
        # is treated as classification.
        output_dim = (
            1
            if not opts.gen.d.classify.enable
            else opts.gen.d.classify.linspace.buckets
        )

        self._target_size = find_target_size(opts, "d")
        print(
            "      - {}:  setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        super().__init__(
            n_upsample=n_upsample,
            n_res=opts.gen.d.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.d.proj_dim,
            output_dim=output_dim,
            norm=opts.gen.d.norm,
            activ=opts.gen.d.activ,
            pad_type=opts.gen.d.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, cond=None):
        """Decode z and bilinearly resize the prediction to the target size.

        Args:
            z: encoder features, forwarded to BaseDecoder
            cond: unused; kept for signature compatibility with other decoders

        Returns:
            tuple: (predictions, None) — None mirrors DADADepthDecoder's
            (depth, z_depth) return shape

        Raises:
            ValueError: if the target size was never set
        """
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size()"
            error += "to interpolate depth to the target depth map's size"
            raise ValueError(error)

        d = super().forward(z)

        preds = F.interpolate(
            d, size=self._target_size, mode="bilinear", align_corners=True
        )

        return preds, None
climategan/discriminator.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Discriminator architecture for ClimateGAN's GAN components (a and t)
2
+ """
3
+ import functools
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from climategan.blocks import SpectralNorm
9
+ from climategan.tutils import init_weights
10
+
11
+ # from torch.optim import lr_scheduler
12
+
13
+ # mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py
14
+
15
+
16
def create_discriminator(opts, device, no_init=False, verbose=0):
    """Build the OmniDiscriminator, initialize its weights and move it to device.

    Args:
        opts (addict.Dict): global options (per-task init settings in opts.dis)
        device (torch.device): destination device
        no_init (bool): when True, return the uninitialized discriminator
            (and do not move it to device)
        verbose (int): verbosity forwarded to init_weights

    Returns:
        OmniDiscriminator: the (initialized) discriminators
    """
    disc = OmniDiscriminator(opts)
    if no_init:
        return disc

    def _init(module, task, caller):
        # Per-task init scheme comes from opts.dis[task].
        init_weights(
            module,
            init_type=opts.dis[task].init_type,
            init_gain=opts.dis[task].init_gain,
            verbose=verbose,
            caller=caller,
        )

    for task, model in disc.items():
        if isinstance(model, nn.ModuleDict):
            # e.g. painter's {global, local} or advent's {Advent} sub-dicts
            for domain, domain_model in model.items():
                _init(domain_model, task, f"create_discriminator {task} {domain}")
        else:
            _init(model, task, f"create_discriminator {task}")
    return disc.to(device)
40
+
41
+
42
def define_D(
    input_nc,
    ndf,
    n_layers=3,
    norm="batch",
    use_sigmoid=False,
    get_intermediate_features=False,
    num_D=1,
):
    """Factory for a MultiscaleDiscriminator with the requested norm layer.

    Args:
        input_nc (int): number of input channels
        ndf (int): base number of discriminator filters
        n_layers (int): strided conv stages per scale
        norm (str): norm layer name, resolved by get_norm_layer
        use_sigmoid (bool): append a sigmoid to each scale
        get_intermediate_features (bool): return per-stage activations
        num_D (int): number of scales

    Returns:
        MultiscaleDiscriminator: the discriminator pyramid
    """
    return MultiscaleDiscriminator(
        input_nc,
        ndf,
        n_layers=n_layers,
        norm_layer=get_norm_layer(norm_type=norm),
        use_sigmoid=use_sigmoid,
        get_intermediate_features=get_intermediate_features,
        num_D=num_D,
    )
62
+
63
+
64
def get_norm_layer(norm_type="instance"):
    """Return a normalization-layer constructor for the given type.

    Args:
        norm_type (str): "batch", "instance" or "none". A falsy value
            (None, "") defaults to "instance".

    Returns:
        callable or None: a ``functools.partial`` of the norm class, or
        None when norm_type is "none".

    Raises:
        NotImplementedError: for any unrecognized norm_type
    """
    if not norm_type:
        # fix: the original printed the literal "{}" — the placeholder was
        # never filled with the actual value
        print("norm_type is {}, defaulting to instance".format(norm_type))
        norm_type = "instance"
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        # InstanceNorm without affine params nor running stats (style-transfer
        # convention)
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    elif norm_type == "none":
        norm_layer = None
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer
79
+
80
+
81
+ # Defines the PatchGAN discriminator with the specified arguments.
82
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator with ``n_layers`` strided conv stages.

    All convolutions are wrapped in spectral normalization. Each stage is
    registered as a separate submodule ("model0", "model1", ...) so that
    forward() can return intermediate activations for feature-matching
    losses when ``get_intermediate_features`` is True.
    """

    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
    ):
        super(NLayerDiscriminator, self).__init__()
        # InstanceNorm has no affine parameters by default, so the convs
        # need their own bias in that case.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.get_intermediate_features = get_intermediate_features

        kw = 4  # kernel size
        padw = 1  # padding
        # `sequence` is a list of stages; each inner list becomes one
        # nn.Sequential submodule below.
        sequence = [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
                ),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        # Stride-2 stages: channel count doubles each stage, capped at ndf*8.
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                [
                    # Use spectral normalization
                    SpectralNorm(  # TODO replace with Conv2dBlock
                        nn.Conv2d(
                            ndf * nf_mult_prev,
                            ndf * nf_mult,
                            kernel_size=kw,
                            stride=2,
                            padding=padw,
                            bias=use_bias,
                        )
                    ),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True),
                ]
            ]

        # One extra stride-1 stage before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw,
                        stride=1,
                        padding=padw,
                        bias=use_bias,
                    )
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        # 1-channel patch prediction head. Use spectral normalization.
        sequence += [
            [
                SpectralNorm(
                    nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
                )
            ]
        ]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module("model" + str(n), nn.Sequential(*sequence[n]))
        # self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Run all stages; return every stage output, or only the last.

        Relies on self.children() yielding the stages in registration
        order (model0, model1, ...).
        """
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        get_intermediate_features = self.get_intermediate_features
        if get_intermediate_features:
            return results[1:]
        else:
            return results[-1]
183
+
184
+
185
+ # def forward(self, input):
186
+ # return self.model(input)
187
+
188
+
189
+ # Source: https://github.com/NVIDIA/pix2pixHD
190
class MultiscaleDiscriminator(nn.Module):
    """Pyramid of ``num_D`` PatchGAN discriminators (pix2pixHD).

    Each scale sees the input downsampled 2x more than the previous one.
    Source: https://github.com/NVIDIA/pix2pixHD
    """

    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
        num_D=3,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.n_layers = n_layers
        self.ndf = ndf
        self.norm_layer = norm_layer
        self.use_sigmoid = use_sigmoid
        self.get_intermediate_features = get_intermediate_features
        self.num_D = num_D

        # One NLayerDiscriminator per scale, registered as
        # "discriminator_<i>" so forward() can find them by name.
        for scale in range(self.num_D):
            self.add_module(
                "discriminator_%d" % scale,
                NLayerDiscriminator(
                    input_nc=input_nc,
                    ndf=self.ndf,
                    n_layers=self.n_layers,
                    norm_layer=self.norm_layer,
                    use_sigmoid=self.use_sigmoid,
                    get_intermediate_features=self.get_intermediate_features,
                ),
            )

        # Anti-aliased 2x downsampling between consecutive scales.
        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def forward(self, input):
        """Return one prediction (as a list) per scale, finest scale first."""
        outputs = []
        for name, disc in self.named_children():
            if "discriminator" not in name:
                continue
            pred = disc(input)
            if not self.get_intermediate_features:
                # Wrap so every scale's output is uniformly a list.
                pred = [pred]
            outputs.append(pred)
            # Next scale receives a 2x-downsampled input.
            input = self.downsample(input)

        return outputs
240
+
241
+
242
class OmniDiscriminator(nn.ModuleDict):
    """Maps each adversarial task to its discriminator(s).

    - "p" (painter): either a {"global", "local"} pair of multi-scale
      discriminators, or a single one over image + mask (4 channels).
    - "m" (masker) and "s" (segmentation): AdvEnt-style discriminators,
      stored under the "Advent" key, only when use_advent is enabled.
    """

    def __init__(self, opts):
        super().__init__()

        def painter_D(input_nc):
            # Multi-scale PatchGAN configured entirely from opts.dis.p
            return define_D(
                input_nc=input_nc,
                ndf=opts.dis.p.ndf,
                n_layers=opts.dis.p.n_layers,
                norm=opts.dis.p.norm,
                use_sigmoid=opts.dis.p.use_sigmoid,
                get_intermediate_features=opts.dis.p.get_intermediate_features,
                num_D=opts.dis.p.num_D,
            )

        if "p" in opts.tasks:
            if opts.dis.p.use_local_discriminator:
                self["p"] = nn.ModuleDict(
                    {"global": painter_D(3), "local": painter_D(3)}
                )
            else:
                self["p"] = painter_D(4)  # image + mask

        if "m" in opts.tasks:
            if opts.gen.m.use_advent:
                if opts.dis.m.architecture == "base":
                    # WGAN_norm selects the spectral-normalized variant.
                    use_norm = opts.dis.m.gan_type == "WGAN_norm"
                    self["m"] = nn.ModuleDict(
                        {
                            "Advent": get_fc_discriminator(
                                num_classes=2, use_norm=use_norm
                            )
                        }
                    )
                elif opts.dis.m.architecture == "OmniDiscriminator":
                    self["m"] = nn.ModuleDict(
                        {
                            "Advent": define_D(
                                input_nc=2,
                                ndf=opts.dis.m.ndf,
                                n_layers=opts.dis.m.n_layers,
                                norm=opts.dis.m.norm,
                                use_sigmoid=opts.dis.m.use_sigmoid,
                                get_intermediate_features=opts.dis.m.get_intermediate_features,  # noqa: E501
                                num_D=opts.dis.m.num_D,
                            )
                        }
                    )
                else:
                    raise Exception("This Discriminator is currently not supported!")

        if "s" in opts.tasks:
            if opts.gen.s.use_advent:
                use_norm = opts.dis.s.gan_type == "WGAN_norm"
                self["s"] = nn.ModuleDict(
                    {"Advent": get_fc_discriminator(num_classes=11, use_norm=use_norm)}
                )
325
+
326
+
327
def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
    """Fully-convolutional AdvEnt discriminator.

    Five 4x4 stride-2 convolutions interleaved with LeakyReLU(0.2);
    channels grow num_classes -> ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 1.
    When ``use_norm`` is True every conv is wrapped in SpectralNorm.

    Args:
        num_classes (int): number of input channels (class maps)
        ndf (int): base number of filters
        use_norm (bool): wrap convolutions in spectral normalization

    Returns:
        torch.nn.Sequential: the discriminator
    """
    widths = [num_classes, ndf, ndf * 2, ndf * 4, ndf * 8, 1]
    modules = []
    for i in range(len(widths) - 1):
        conv = torch.nn.Conv2d(
            widths[i], widths[i + 1], kernel_size=4, stride=2, padding=1
        )
        modules.append(SpectralNorm(conv) if use_norm else conv)
        # No activation after the final 1-channel prediction conv.
        if i < len(widths) - 2:
            modules.append(torch.nn.LeakyReLU(negative_slope=0.2, inplace=True))
    return torch.nn.Sequential(*modules)
climategan/eval_metrics.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from skimage import filters
5
+ from sklearn.metrics.pairwise import euclidean_distances
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ from copy import deepcopy
9
+
10
+ # ------------------------------------------------------------------------------
11
+ # ----- Evaluation metrics for a pair of binary mask images (pred, target) -----
12
+ # ------------------------------------------------------------------------------
13
+
14
+
15
def get_accuracy(arr1, arr2):
    """Pixel accuracy: fraction of positions where arr1 equals arr2.

    Args:
        arr1 (np.array): first array (e.g. prediction)
        arr2 (np.array): second array (e.g. target), broadcastable to arr1

    Returns:
        float: number of matching elements divided by arr1.size
    """
    n_equal = np.count_nonzero(arr1 == arr2)
    return n_equal / arr1.size
23
+
24
+
25
def trimap(pred_im, gt_im, thickness=8):
    """Compute accuracy in a region of thickness around the contours
    for binary images (0-1 values).

    The ground-truth contours are drawn as a band of the given thickness,
    and pixel accuracy is computed only within that band.

    Args:
        pred_im (Image): Prediction (binary, 0-1)
        gt_im (Image): Target (binary, 0-1); must be a PIL Image
            (``.size`` is read as (width, height))
        thickness (int, optional): width in pixels of the contour band.
            Defaults to 8.

    Returns:
        float: pixel accuracy restricted to the contour band
    """
    W, H = gt_im.size
    contours, hierarchy = cv2.findContours(
        np.array(gt_im), mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
    )
    mask_contour = np.zeros((H, W), dtype=np.int32)
    # Draw all contours (index -1) with value 1 and the requested thickness.
    cv2.drawContours(
        mask_contour, contours, -1, (1), thickness=thickness, hierarchy=hierarchy
    )
    # Restrict both images to the contour band before measuring accuracy.
    gt_contour = np.array(gt_im)[np.where(mask_contour > 0)]
    pred_contour = np.array(pred_im)[np.where(mask_contour > 0)]
    return get_accuracy(pred_contour, gt_contour)
44
+
45
+
46
def iou(pred_im, gt_im):
    """Intersection over Union for binary masks (0-1 values).

    Args:
        pred_im: predicted binary mask
        gt_im: ground-truth binary mask

    Returns:
        float: |pred AND gt| / |pred OR gt|
    """
    pred = np.asarray(pred_im)
    gt = np.asarray(gt_im)
    inter = np.sum(pred * gt)
    # union = |pred| + |gt| - |intersection|
    total = np.sum(pred) + np.sum(gt)
    return inter / (total - inter)
59
+
60
+
61
def f1_score(pred_im, gt_im):
    """Dice / F1 score for binary masks: 2 * |pred AND gt| / (|pred| + |gt|).

    Args:
        pred_im: predicted binary mask (0-1 values)
        gt_im: ground-truth binary mask (0-1 values)

    Returns:
        float: the F1 (Dice) score
    """
    pred = np.asarray(pred_im)
    gt = np.asarray(gt_im)
    overlap = np.sum(pred * gt)
    return 2 * overlap / (np.sum(pred) + np.sum(gt))
66
+
67
+
68
def accuracy(pred_im, gt_im):
    """Pixel accuracy between a prediction and a ground-truth map.

    A 4D ground truth (N, 1, H, W) has its singleton channel squeezed;
    a prediction with more dimensions than the ground truth is assumed to
    carry per-class scores and is argmaxed over axis 1.

    Args:
        pred_im (array-like): prediction, either labels (..., H, W) or
            per-class scores (N, C, H, W)
        gt_im (array-like): ground-truth labels, (N, H, W) or (N, 1, H, W)

    Returns:
        float: fraction of matching pixels
    """
    pred = np.asarray(pred_im)
    gt = np.asarray(gt_im)
    # fix: the original squeezed `gt_im` but kept comparing against the
    # stale pre-squeeze `gt` array, so the squeeze never took effect; it
    # also mixed `gt_im` and `gt` in the dimensionality checks.
    if gt.ndim == 4:
        assert gt.shape[1] == 1
        gt = gt[:, 0, :, :]
    if pred.ndim > gt.ndim:
        pred = np.argmax(pred, axis=1)
    return float((pred == gt).sum()) / gt.size
77
+
78
+
79
def mIOU(pred, label, average="macro"):
    """
    Mean IoU between predicted logits and integer labels.

    Adapted from:
    https://stackoverflow.com/questions/62461379/multiclass-semantic-segmentation-model-evaluation

    ``pred`` is an N x C x H x W tensor of per-class scores (hard labels
    are taken with argmax over the class axis); ``label`` is an N x H x W
    tensor of int labels. With average="macro" this matches sklearn's
    jaccard_score. Classes absent from both prediction and target are
    skipped; in the binary case only the positive class present in the
    labels is scored.

    Args:
        pred (torch.tensor): predicted scores/logits
        label (torch.tensor): integer labels
        average: "macro" or "weighted" (weighted by predicted class size)

    Returns:
        float: mIOU, can be nan
    """
    num_classes = pred.shape[-3]

    # Flatten hard predictions and labels to 1D.
    flat_pred = torch.argmax(pred, dim=1).squeeze(1).view(-1)
    flat_label = label.view(-1)

    # Binary case: score only the positive class present in the labels.
    if num_classes > 2:
        classes = list(range(num_classes))
    else:
        classes = [int(flat_label.max().item())]

    ious = []
    weights = []
    for c in classes:
        pred_mask = flat_pred == c
        target_mask = flat_label == c
        n_pred = pred_mask.long().sum().item()
        n_target = target_mask.long().sum().item()
        if n_pred == 0 and n_target == 0:
            # Class absent from both: it does not contribute to the mean.
            continue
        n_inter = (pred_mask[target_mask]).long().sum().item()
        n_union = n_pred + n_target - n_inter
        weights.append(n_pred)
        ious.append(float(n_inter) / float(n_union))

    if not ious:
        return float("nan")
    if average == "weighted":
        return np.sum(np.multiply(weights, ious) / np.sum(weights))
    return np.mean(ious)
131
+
132
+
133
def masker_classification_metrics(
    pred, label, labels_dict=None
):
    """
    Classification metrics for the masker, and the corresponding maps. If the
    predictions are soft, the errors are weighted accordingly.

    Metrics computed (all rates in [0, 1]; totals are divided by the total
    pixel population):

    tpr / tpt : true positive rate / total (mask predicted on must flood)
    tnr / tnt : true negative rate / total (no mask on cannot flood)
    fpr / fpt : false positive rate / total (mask predicted on cannot flood)
    fnr / fnt : false negative rate / total (missed mask on must flood)
    mnr / mpr : "may" negative / positive rates (label "may", predicted
        no-mask / mask)
    accuracy : tpt + tnt
    error : fpt + fnt
    precision : precision over cannot and must flood labels only
    f05 : F0.5 score over cannot and must flood labels only
    accuracy_must_may : accuracy restricted to the must and may areas

    Parameters
    ----------
    pred : array-like
        Mask prediction (hard 0/1 or soft in [0, 1])

    label : array-like
        Mask ground truth labels

    labels_dict : dict, optional
        A dictionary with the identifier of each class (cannot, must, may).
        Defaults to {"cannot": 0, "must": 1, "may": 2}.

    Returns
    -------
    metrics_dict : dict
        A dictionary with metric name and value pairs

    maps_dict : dict
        A dictionary containing the metric maps (tp, tn, fp, fn,
        may_pos, may_neg)
    """
    # fix: avoid a mutable default argument; build the default mapping
    # per call instead.
    if labels_dict is None:
        labels_dict = {"cannot": 0, "must": 1, "may": 2}

    # Per-class indicator masks, computed once and reused below.
    cannot = np.asarray(label == labels_dict["cannot"], dtype=int)
    must = np.asarray(label == labels_dict["must"], dtype=int)
    may = np.asarray(label == labels_dict["may"], dtype=int)
    total = np.prod(label.shape)

    tp_map = pred * must
    tn_map = (1.0 - pred) * cannot
    fp_map = pred * cannot
    fn_map = (1.0 - pred) * must
    may_neg_map = (1.0 - pred) * may
    may_pos_map = pred * may

    tpr = np.sum(tp_map) / np.sum(must)
    tpt = np.sum(tp_map) / total
    tnr = np.sum(tn_map) / np.sum(cannot)
    tnt = np.sum(tn_map) / total
    fpr = np.sum(fp_map) / np.sum(cannot)
    fpt = np.sum(fp_map) / total
    fnr = np.sum(fn_map) / np.sum(must)
    fnt = np.sum(fn_map) / total
    mnr = np.sum(may_neg_map) / np.sum(may)
    mpr = np.sum(may_pos_map) / np.sum(may)
    accuracy = tpt + tnt
    error = fpt + fnt

    # Sanity checks: complementary rates must sum to 1.
    assert np.isclose(tpr, 1.0 - fnr), "TPR: {:.4f}, FNR: {:.4f}".format(tpr, fnr)
    assert np.isclose(tnr, 1.0 - fpr), "TNR: {:.4f}, FPR: {:.4f}".format(tnr, fpr)
    assert np.isclose(mpr, 1.0 - mnr), "MPR: {:.4f}, MNR: {:.4f}".format(mpr, mnr)

    # Precision / F0.5 consider only cannot and must labels; epsilons
    # guard against division by zero.
    precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map) + 1e-9)
    beta = 0.5
    f05 = ((1 + beta ** 2) * precision * tpr) / (beta ** 2 * precision + tpr + 1e-9)
    accuracy_must_may = (np.sum(tp_map) + np.sum(may_neg_map)) / (
        np.sum(must) + np.sum(may)
    )

    metrics_dict = {
        "tpr": tpr,
        "tpt": tpt,
        "tnr": tnr,
        "tnt": tnt,
        "fpr": fpr,
        "fpt": fpt,
        "fnr": fnr,
        "fnt": fnt,
        "mpr": mpr,
        "mnr": mnr,
        "accuracy": accuracy,
        "error": error,
        "precision": precision,
        "f05": f05,
        "accuracy_must_may": accuracy_must_may,
    }
    maps_dict = {
        "tp": tp_map,
        "tn": tn_map,
        "fp": fp_map,
        "fn": fn_map,
        "may_pos": may_pos_map,
        "may_neg": may_neg_map,
    }

    return metrics_dict, maps_dict
262
+
263
+
264
def pred_cannot(pred, label, label_cannot=0):
    """
    Masker metric: false-positive map and rate, i.e. mask predicted on
    "cannot flood" pixels. If the predictions are soft, the errors are
    weighted accordingly.

    Parameters
    ----------
    pred : array-like
        Mask prediction

    label : array-like
        Mask ground truth labels

    label_cannot : int
        The label index of "cannot flood"

    Returns
    -------
    fp_map : array-like
        The map of false positives: predicted mask on cannot flood

    fpr : float
        False positive rate: rate of predicted mask on cannot flood
    """
    cannot_mask = np.asarray(label == label_cannot, dtype=int)
    fp_map = pred * cannot_mask
    fpr = np.sum(fp_map) / np.sum(cannot_mask)
    return fp_map, fpr
291
+
292
+
293
def missed_must(pred, label, label_must=1):
    """
    Masker metric: false negative rate and its per-pixel map.
    Soft (non-binary) predictions weight the errors accordingly.

    Parameters
    ----------
    pred : array-like
        Mask prediction

    label : array-like
        Mask ground truth labels

    label_must : int
        The label index of "must flood"

    Returns
    -------
    fn_map : array-like
        The map of false negatives: missed mask on must flood

    fnr : float
        False negative rate: rate of missed mask on must flood
    """
    # Indicator of the "must flood" region
    must_region = np.asarray(label == label_must, dtype=int)
    fn_map = (1.0 - pred) * must_region
    fnr = np.sum(fn_map) / np.sum(must_region)
    return fn_map, fnr
320
+
321
+
322
+ def may_flood(pred, label, label_may=2):
323
+ """
324
+ Metric for the masker: Computes "may" negative and "may" positive rates and their
325
+ map. If the predictions are soft, the "errors" are weighted accordingly.
326
+
327
+ Parameters
328
+ ----------
329
+ pred : array-like
330
+ Mask prediction
331
+
332
+ label : array-like
333
+ Mask ground truth labels
334
+
335
+ label_may : int
336
+ The label index of "may flood"
337
+
338
+ Returns
339
+ -------
340
+ may_neg_map : array-like
341
+ The map of "may" negatives
342
+
343
+ may_pos_map : array-like
344
+ The map of "may" positives
345
+
346
+ mnr : float
347
+ "May" negative rate
348
+
349
+ mpr : float
350
+ "May" positive rate
351
+ """
352
+ may_neg_map = (1.0 - pred) * np.asarray(label == label_may, dtype=int)
353
+ may_pos_map = pred * np.asarray(label == label_may, dtype=int)
354
+ mnr = np.sum(may_neg_map) / np.sum(label == label_may)
355
+ mpr = np.sum(may_pos_map) / np.sum(label == label_may)
356
+ return may_neg_map, may_pos_map, mnr, mpr
357
+
358
+
359
def masker_metrics(pred, label, label_cannot=0, label_must=1):
    """
    Computes a set of metrics for the masker.

    Fix: the original divided by ``tp + fp`` (precision) and
    ``precision + tpr`` (F1) with no guard, producing a ZeroDivisionError /
    NaN whenever the prediction contains no positives. A small epsilon is
    added to the denominators, consistent with the other metric
    computations in this module; the (unused) false-negative map is dropped.

    Parameters
    ----------
    pred : array-like
        Mask prediction

    label : array-like
        Mask ground truth labels

    label_cannot : int
        The label index of "cannot flood"

    label_must : int
        The label index of "must flood"

    Returns
    -------
    tpr : float
        True positive rate

    tnr : float
        True negative rate

    precision : float
        Precision, considering only cannot and must flood labels
        (0.0 when nothing is predicted positive)

    f1 : float
        F1 score, considering only cannot and must flood labels
        (0.0 when both precision and recall are 0)
    """
    eps = 1e-9  # guards 0-denominators, same constant as the sibling metrics

    must = np.asarray(label == label_must, dtype=int)
    cannot = np.asarray(label == label_cannot, dtype=int)

    tp_map = pred * must
    tpr = np.sum(tp_map) / np.sum(must)
    tn_map = (1.0 - pred) * cannot
    tnr = np.sum(tn_map) / np.sum(cannot)
    fp_map = pred * cannot

    precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map) + eps)
    f1 = 2 * (precision * tpr) / (precision + tpr + eps)
    return tpr, tnr, precision, f1
400
+
401
+
402
def get_confusion_matrix(tpr, tnr, fpr, fnr, mpr, mnr):
    """
    Constructs the confusion matrix of a masker prediction over a set of samples.

    Parameters
    ----------
    tpr : vector-like
        True positive rate

    tnr : vector-like
        True negative rate

    fpr : vector-like
        False positive rate

    fnr : vector-like
        False negative rate

    mpr : vector-like
        "May" positive rate

    mnr : vector-like
        "May" negative rate

    Returns
    -------
    confusion_matrix : 3x3 array
        Confusion matrix: [i, j] = [pred, true]
        | tnr fnr mnr |
        | fpr tpr mpr |
        | 0.  0.  0.  |

    confusion_matrix_std : 3x3 array
        Standard deviation of the confusion matrix
    """
    # Mean and standard deviation of each rate over all samples
    stats = {
        name: (np.mean(values), np.std(values))
        for name, values in zip(
            ("tpr", "tnr", "fpr", "fnr", "mpr", "mnr"),
            (tpr, tnr, fpr, fnr, mpr, mnr),
        )
    }

    # Sanity checks: complementary mean rates must sum to 1
    for pos, neg in (("tpr", "fnr"), ("tnr", "fpr"), ("mpr", "mnr")):
        assert np.isclose(stats[pos][0], 1.0 - stats[neg][0]), (
            "{}: {:.4f}, {}: {:.4f}".format(
                pos.upper(), stats[pos][0], neg.upper(), stats[neg][0]
            )
        )

    # Cell layout of the confusion matrix; the last row stays 0
    cells = {
        (0, 0): "tnr",
        (0, 1): "fnr",
        (0, 2): "mnr",
        (1, 0): "fpr",
        (1, 1): "tpr",
        (1, 2): "mpr",
    }
    confusion_matrix = np.zeros((3, 3))
    confusion_matrix_std = np.zeros((3, 3))
    for (i, j), name in cells.items():
        confusion_matrix[i, j] = stats[name][0]
        confusion_matrix_std[i, j] = stats[name][1]

    return confusion_matrix, confusion_matrix_std
482
+
483
+
484
def edges_coherence_std_min(pred, label, label_must=1, bin_th=0.5):
    """
    The standard deviation of the minimum distance between the edge of the
    prediction and the edge of the "must flood" label.

    Parameters
    ----------
    pred : array-like
        Mask prediction

    label : array-like
        Mask ground truth labels

    label_must : int
        The label index of "must flood"

    bin_th : float
        The threshold for the binarization of the prediction

    Returns
    -------
    metric : float
        The value of the metric (1.0 when either edge set is empty)

    pred_edge : array-like
        The edges images of the prediction, for visualization

    label_edge : array-like
        The edges images of the "must flood" label, for visualization
    """
    # Binary map of the "must flood" region.
    # Bug fix: the previous -1 / 1 / 0 relabeling sequence zeroed the whole
    # map whenever label_must != 1, because the freshly-written 1s matched
    # the final `label != label_must` pass. A direct comparison is
    # equivalent for the default label_must=1 and correct for every other
    # value; it also allocates a new array, so the caller's label is never
    # mutated (deepcopy no longer needed).
    label = np.asarray(label == label_must, dtype=float)

    # Binarize prediction
    pred = np.asarray(pred > bin_th, dtype=float)

    # Compute edges
    pred = filters.sobel(pred)
    label = filters.sobel(label)

    # Location of edges
    pred_coord = np.argwhere(pred > 0)
    label_coord = np.argwhere(label > 0)

    # Handle blank predictions or blank labels: no distances to measure
    if pred_coord.shape[0] == 0 or label_coord.shape[0] == 0:
        return 1.0, pred, label

    # Pairwise distances between pred and label edges, normalized by height
    dist_mat = np.divide(euclidean_distances(pred_coord, label_coord), pred.shape[0])

    # Standard deviation of the minimum distance from pred to label
    edge_coherence = np.std(np.min(dist_mat, axis=1))

    return edge_coherence, pred, label
543
+
544
+
545
def boxplot_metric(
    output_filename,
    df,
    metric,
    dict_metrics,
    do_stripplot=False,
    dict_models=None,
    dpi=300,
    **snskwargs
):
    """
    Saves to `output_filename` a boxplot of `metric` grouped by "model".

    df: long-format DataFrame with at least a "model" and a `metric` column.
    dict_metrics: maps metric key -> display name used as the y-axis label.
    do_stripplot: overlay individual data points (box fliers are then hidden).
    dict_models: optional map of model key -> display name for x-tick labels.
    dpi: figure resolution; extra kwargs are forwarded to the seaborn calls.
    """
    fig = plt.figure(dpi=dpi)

    if do_stripplot:
        # fliers would be redundant with the stripplot points
        ax = sns.boxplot(x="model", y=metric, data=df, fliersize=0.0, **snskwargs)
        ax = sns.stripplot(
            x="model", y=metric, data=df, size=2.0, color="gray", **snskwargs
        )
    else:
        ax = sns.boxplot(x="model", y=metric, data=df, **snskwargs)

    # Axes labels
    ax.set_xlabel("Models", rotation=0, fontsize="medium")
    ax.set_ylabel(dict_metrics[metric], rotation=90, fontsize="medium")

    # Spines
    sns.despine(left=True, bottom=True)

    # Pretty x-tick labels, when a mapping is provided
    if dict_models:
        pretty_labels = [dict_models[tick.get_text()] for tick in ax.get_xticklabels()]
        ax.set_xticklabels(
            pretty_labels,
            rotation=20,
            verticalalignment="top",
            horizontalalignment="right",
            fontsize="xx-small",
        )

    fig.savefig(
        output_filename,
        dpi=fig.dpi,
        bbox_inches="tight",
        facecolor="white",
        transparent=False,
    )
    fig.clear()
    plt.close(fig)
592
+
593
+
594
def clustermap_metric(
    output_filename,
    df,
    metric,
    dict_metrics,
    method="average",
    cluster_metric="euclidean",
    dict_models=None,
    dpi=300,
    **snskwargs
):
    """
    Saves to `output_filename` a seaborn clustermap of `df` (images x models),
    with the display name of `metric` as the colorbar title.

    method / cluster_metric: linkage method and distance metric forwarded to
    `sns.clustermap`.
    dict_models: optional map of model key -> display name for x-tick labels.
    """
    grid = sns.clustermap(data=df, method=method, metric=cluster_metric, **snskwargs)
    heatmap_ax = grid.ax_heatmap
    cbar_ax = grid.ax_cbar

    # Axes labels
    heatmap_ax.set_xlabel("Models", rotation=0, fontsize="medium")
    heatmap_ax.set_ylabel("Images", rotation=90, fontsize="medium")

    # The metric name goes on the colorbar
    cbar_ax.set_title(dict_metrics[metric], rotation=0, fontsize="x-large")

    # Pretty x-tick labels, when a mapping is provided
    if dict_models:
        pretty_labels = [
            dict_models[tick.get_text()] for tick in heatmap_ax.get_xticklabels()
        ]
        heatmap_ax.set_xticklabels(
            pretty_labels,
            rotation=20,
            verticalalignment="top",
            horizontalalignment="right",
            fontsize="small",
        )

    grid.fig.savefig(
        output_filename,
        dpi=dpi,
        bbox_inches="tight",
        facecolor="white",
        transparent=False,
    )
    grid.fig.clear()
    plt.close(grid.fig)
climategan/fid.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/mseitzer/pytorch-fid
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision
8
+ from scipy import linalg
9
+ from torch.nn.functional import adaptive_avg_pool2d
10
+
11
+ try:
12
+ from torchvision.models.utils import load_state_dict_from_url
13
+ except ImportError:
14
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
15
+
16
+ FID_WEIGHTS_URL = (
17
+ "https://github.com/mseitzer/pytorch-fid/releases/download/"
18
+ + "fid_weights/pt_inception-2015-12-05-6726825d.pth"
19
+ )
20
+
21
+
22
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=[DEFAULT_BLOCK_INDEX],
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3
        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
            - 0: corresponds to output of first max pooling
            - 1: corresponds to output of second max pooling
            - 2: corresponds to output which is fed to aux classifier
            - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        # blocks are always returned in ascending index order
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, "Last possible output block index is 3"

        self.blocks = nn.ModuleList()

        # FID weights / structure differ from torchvision's: see
        # fid_inception_v3() below
        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        # (deeper blocks are only built if a requested output needs them)
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        # Freeze (or unfreeze) all parameters at once
        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps
        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)
        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # stop early: deeper blocks are not needed (and may not exist)
            if idx == self.last_needed_block:
                break

        return outp
165
+
166
+
167
def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`
    Skips default weight initialization if supported by torchvision version.
    See https://github.com/mseitzer/pytorch-fid/issues/28.
    """
    try:
        major_minor = tuple(int(part) for part in torchvision.__version__.split(".")[:2])
    except ValueError:
        # Just a caution against weird version strings
        major_minor = (0,)

    # torchvision >= 0.6 accepts init_weights; skipping the default
    # initialization is pointless work since weights are loaded afterwards
    if major_minor >= (0, 6):
        kwargs["init_weights"] = False

    return torchvision.models.inception_v3(*args, **kwargs)
182
+
183
+
184
def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The FID Inception network uses a different set of weights and has a
    slightly different structure than torchvision's Inception. Torchvision's
    model is built first, the differing blocks are swapped for their FID
    variants, and the FID weights are then loaded on top.
    """
    model = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # Replace the blocks whose pooling behavior differs in the FID model
    model.Mixed_5b = FIDInceptionA(192, pool_features=32)
    model.Mixed_5c = FIDInceptionA(256, pool_features=64)
    model.Mixed_5d = FIDInceptionA(288, pool_features=64)
    model.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    model.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    model.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    model.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    model.Mixed_7b = FIDInceptionE_1(1280)
    model.Mixed_7c = FIDInceptionE_2(2048)

    model.load_state_dict(load_state_dict_from_url(FID_WEIGHTS_URL, progress=True))
    return model
205
+
206
+
207
class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation

    The forward pass mirrors torchvision's InceptionA; the only intended
    difference is the TF-compatible average pooling (see comment below).
    Kept textually close to upstream pytorch-fid for easy diffing.
    """

    def __init__(self, in_channels, pool_features):
        # the torchvision parent builds all the branch submodules
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
232
+
233
+
234
class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation

    The forward pass mirrors torchvision's InceptionC; the only intended
    difference is the TF-compatible average pooling (see comment below).
    Kept textually close to upstream pytorch-fid for easy diffing.
    """

    def __init__(self, in_channels, channels_7x7):
        # the torchvision parent builds all the branch submodules
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)
262
+
263
+
264
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation

    The forward pass mirrors torchvision's InceptionE; the only intended
    difference is the TF-compatible average pooling (see comment below).
    Kept textually close to upstream pytorch-fid for easy diffing.
    """

    def __init__(self, in_channels):
        # the torchvision parent builds all the branch submodules
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
297
+
298
+
299
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation

    The forward pass mirrors torchvision's InceptionE; the only intended
    difference is the max pooling in the pool branch (see comment below).
    Kept textually close to upstream pytorch-fid for easy diffing.
    """

    def __init__(self, in_channels):
        # the torchvision parent builds all the branch submodules
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
332
+
333
+
334
def compute_val_fid(trainer, verbose=0):
    """
    FID between the n = opts.train.fid.n_images real validation images
    (rf domain) and the n fake images painted from them.

    Args:
        trainer (climategan.Trainer): trainer to compute the val fid for
        verbose (int, optional): print progress information. Defaults to 0.

    Returns:
        float: FID score
    """
    # fid-specific options
    fid_opts = trainer.opts.train.fid
    batch_size = fid_opts.get("batch_size", 50)
    dims = fid_opts.get("dims", 2048)

    # inception model truncated at the block matching `dims`
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(trainer.device)

    # real-image statistics are cached on the trainer: compute them once only
    if trainer.real_val_fid_stats is None:
        if verbose > 0:
            print("Computing real_val_fid_stats for the first time")
        set_real_val_fid_stats(trainer, model, batch_size, dims)

    real_m = trainer.real_val_fid_stats["m"]
    real_s = trainer.real_val_fid_stats["s"]

    # paint the fake images, then get their inception statistics
    fakes = compute_fakes(trainer)
    if verbose > 0:
        print("Computing fake activation statistics")
    fake_m, fake_s = calculate_activation_statistics(
        fakes, model, batch_size=batch_size, dims=dims, device=trainer.device
    )

    # Frechet distance between the real and fake inception stats
    return calculate_frechet_distance(real_m, real_s, fake_m, fake_s)
374
+
375
+
376
def set_real_val_fid_stats(trainer, model, batch_size, dims):
    """
    Stores on `trainer.real_val_fid_stats` the inception statistics
    ("m" and "s" from calculate_activation_statistics) of the real
    validation images.

    This needs to be done only once since nothing changes during training here.

    Args:
        trainer (climategan.Trainer): trainer instance to compute the stats for
        model (InceptionV3): inception model to get the activations from
        batch_size (int): inception inference batch size
        dims (int): dimension selected in the model
    """
    # in the rf domain display_size may be different from fid.n_images
    limit = trainer.opts.train.fid.n_images
    real_images = torch.stack(
        [sample["data"]["x"] for sample in trainer.display_images["val"]["rf"][:limit]]
    ).to(trainer.device)
    m, s = calculate_activation_statistics(
        real_images, model, batch_size=batch_size, dims=dims, device=trainer.device
    )
    trainer.real_val_fid_stats = {"m": m, "s": s}
398
+
399
+
400
def compute_fakes(trainer, verbose=0):
    """
    Compute current fake inferences

    Fix: the masks tensor previously stacked ``b[0]`` (the images) instead of
    ``b[1]`` (the masks), so the painter received the image as its mask.
    Also uses ceil division for the number of batches so no trailing empty
    batch is produced when the sample count is a multiple of the batch size.

    Args:
        trainer (climategan.Trainer): trainer instance
        verbose (int, optional): Print level. Defaults to 0.

    Returns:
        torch.Tensor: trainer.opts.train.fid.n_images painted images
    """
    # in the rf domain display_size may be different from fid.n_images
    n = trainer.opts.train.fid.n_images
    bs = trainer.opts.data.loaders.batch_size

    display_batches = [
        (sample["data"]["x"], sample["data"]["m"])
        for sample in trainer.display_images["val"]["rf"][:n]
    ]

    display_x = torch.stack([b[0] for b in display_batches]).to(trainer.device)
    # b[1] is the mask (b[0] is the image)
    display_m = torch.stack([b[1] for b in display_batches]).to(trainer.device)
    # ceil division: no empty trailing batch
    nbs = (len(display_x) + bs - 1) // bs

    fakes = []
    for b in range(nbs):
        if verbose > 0:
            print("computing fakes {}/{}".format(b + 1, nbs), end="\r", flush=True)
        with torch.no_grad():
            x = display_x[b * bs : (b + 1) * bs]
            m = display_m[b * bs : (b + 1) * bs]
            fake = trainer.G.paint(m, x)
            fakes.append(fake)

    return torch.cat(fakes, dim=0)
435
+
436
+
437
def calculate_activation_statistics(
    images, model, batch_size=50, dims=2048, device="cpu"
):
    """Calculation of the statistics used by the FID.
    Params:
    -- images      : List of images
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    activations = get_activations(images, model, batch_size, dims, device)
    # mean and covariance over the sample axis
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
459
+
460
+
461
def get_activations(images, model, batch_size=50, dims=2048, device="cpu"):
    """Calculates the activations of the pool_3 layer for all images.
    Params:
    -- images      : List of images
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    features = np.empty((len(images), dims))

    # contiguous batches: the feature rows for batch starting at `start`
    # land at the same offset in `features`
    for start in range(0, len(images), batch_size):
        batch = images[start : start + batch_size].to(device)

        with torch.no_grad():
            out = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if out.size(2) != 1 or out.size(3) != 1:
            out = adaptive_avg_pool2d(out, output_size=(1, 1))

        out = out.squeeze(3).squeeze(2).cpu().numpy()
        features[start : start + out.shape[0]] = out

    return features
505
+
506
+
507
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : activation mean for the generated samples
    -- mu2   : activation mean, precalculated on a representative data set
    -- sigma1: activation covariance for the generated samples
    -- sigma2: activation covariance, precalculated on a representative set
    -- eps   : diagonal offset used when the covariance product is singular
    Returns:
    --       : The Frechet Distance.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        print(
            (
                "fid calculation produces singular product; "
                "adding %s to diagonal of cov estimates"
            )
            % eps
        )
        # retry with a slightly regularized product
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            raise ValueError(
                "Imaginary component {}".format(np.max(np.abs(covmean.imag)))
            )
        covmean = covmean.real

    trace_terms = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return diff.dot(diff) + trace_terms
climategan/fire.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import random
4
+ import kornia
5
+ from torchvision.transforms.functional import adjust_brightness, adjust_contrast
6
+
7
+ from climategan.tutils import normalize, retrieve_sky_mask
8
+
9
+ try:
10
+ from kornia.filters import filter2d
11
+ except ImportError:
12
+ from kornia.filters import filter2D as filter2d
13
+
14
+
15
+ def increase_sky_mask(mask, p_w=0, p_h=0):
16
+ """
17
+ Increases sky mask in width and height by a given pourcentage
18
+ (Purpose: when applying Gaussian blur, there are no artifacts of blue sky behind)
19
+ Args:
20
+ sky_mask (torch.Tensor): Sky mask of shape (H,W)
21
+ p_w (float): Percentage of mask width by which to increase
22
+ the width of the sky region
23
+ p_h (float): Percentage of mask height by which to increase
24
+ the height of the sky region
25
+ Returns:
26
+ torch.Tensor: Sky mask increased given p_w and p_h
27
+ """
28
+
29
+ if p_h <= 0 and p_w <= 0:
30
+ return mask
31
+
32
+ n_lines = int(p_h * mask.shape[-2])
33
+ n_cols = int(p_w * mask.shape[-1])
34
+
35
+ temp_mask = mask.clone().detach()
36
+ for i in range(1, n_cols):
37
+ temp_mask[:, :, :, i::] += mask[:, :, :, 0:-i]
38
+ temp_mask[:, :, :, 0:-i] += mask[:, :, :, i::]
39
+
40
+ new_mask = temp_mask.clone().detach()
41
+ for i in range(1, n_lines):
42
+ new_mask[:, :, i::, :] += temp_mask[:, :, 0:-i, :]
43
+ new_mask[:, :, 0:-i, :] += temp_mask[:, :, i::, :]
44
+
45
+ new_mask[new_mask >= 1] = 1
46
+
47
+ return new_mask
48
+
49
+
50
def paste_filter(x, filter_, mask):
    """
    Blends a filter over an image according to a mask.
    Where the mask is 1, the filter is copied as is.
    Where the mask is 0, the current value is preserved.
    Intermediate values mix the two images together.

    Args:
        x (torch.Tensor): Input tensor, range must be [0, 255]
        filter_ (torch.Tensor): Filter, range must be [0, 255]
        mask (torch.Tensor): Mask, range must be [0, 1]

    Returns:
        torch.Tensor: New tensor with the filter pasted on it
    """
    # all three tensors must have the same number of dimensions
    assert len(x.shape) == len(filter_.shape) == len(mask.shape)
    blended = mask * filter_ + (1 - mask) * x
    return blended
66
+
67
+
68
def add_fire(x, seg_preds, fire_opts):
    """
    Transforms an input tensor to simulate a wildfires event:
    warms/darkens the image and replaces the sky with a blurred
    orange glow.

    Args:
        x (torch.Tensor): Input tensor (B3HW)
        seg_preds (torch.Tensor): Semantic segmentation predictions for
            the input tensor, used to locate the sky
        fire_opts (dict-like): wildfire options; keys read here:
            crop_bottom_sky_mask (bool): zero the sky mask below the top
                two thirds of the image
            kernel_size (int): Gaussian kernel size, defaults to 301
            kernel_sigma (float): Gaussian sigma, defaults to 150.5
    Returns:
        torch.Tensor: Wildfire version of the input tensor
    """
    wildfire_tens = normalize(x, 0, 255)

    # Warm the image: shift channels towards red (RGB order assumed here)
    wildfire_tens[:, 2, :, :] -= 20
    wildfire_tens[:, 1, :, :] -= 10
    wildfire_tens[:, 0, :, :] += 40
    wildfire_tens.clamp_(0, 255)
    # torchvision's adjust_* expect uint8 (or float) image tensors
    wildfire_tens = wildfire_tens.to(torch.uint8)

    # Darken the picture and increase contrast
    wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5)
    wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73)

    # binary sky mask from the segmentation predictions, with a channel dim
    sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1)

    if fire_opts.get("crop_bottom_sky_mask"):
        # ignore "sky" detections in the bottom third of the image
        i = 2 * sky_mask.shape[-2] // 3
        sky_mask[..., i:, :] = 0

    # resize the mask to the image resolution
    sky_mask = F.interpolate(
        sky_mask.to(torch.float),
        (wildfire_tens.shape[-2], wildfire_tens.shape[-1]),
    )
    # grow the mask by 18% in both directions to avoid blue halos after blur
    sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18)

    # Gaussian-blur the mask so the sky/foreground transition is smooth
    kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301))
    sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5))
    border_type = "reflect"
    kernel = torch.unsqueeze(
        kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma), dim=0
    ).to(x.device)
    sky_mask = filter2d(sky_mask, kernel, border_type)

    # orange glow filter: full red, random green, no blue
    filter_ = torch.ones(wildfire_tens.shape, device=x.device)
    filter_[:, 0, :, :] = 255
    filter_[:, 1, :, :] = random.randint(100, 150)
    filter_[:, 2, :, :] = 0

    # paste the glow over the sky with transparency 200/255
    wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200)

    wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8)
    wildfire_tens = wildfire_tens.to(torch.float)

    # dummy pixels to fool scaling and preserve range
    wildfire_tens[:, :, 0, 0] = 255.0
    wildfire_tens[:, :, -1, -1] = 0.0

    return wildfire_tens
128
+
129
+
130
def paste_tensor(source, filter_, mask, transparency):
    """
    Blends `filter_` over `source` with a per-pixel weight given by
    `mask` scaled by a global `transparency` (in [0, 255]).

    Args:
        source (torch.Tensor): image to paste onto
        filter_ (torch.Tensor): overlay to paste
        mask (torch.Tensor): per-pixel blending weights in [0, 1]
        transparency (float): global opacity of the overlay, in [0, 255]

    Returns:
        torch.Tensor: the blended tensor
    """
    alpha = mask * (transparency / 255.0)
    return alpha * filter_ + (1.0 - alpha) * source
climategan/generator.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Complete Generator architecture:
2
+ * OmniGenerator
3
+ * Encoder
4
+ * Decoders
5
+ """
6
+ from pathlib import Path
7
+ import traceback
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import yaml
13
+ from addict import Dict
14
+ from torch import softmax
15
+
16
+ import climategan.strings as strings
17
+ from climategan.deeplab import create_encoder, create_segmentation_decoder
18
+ from climategan.depth import create_depth_decoder
19
+ from climategan.masker import create_mask_decoder
20
+ from climategan.painter import create_painter
21
+ from climategan.tutils import init_weights, mix_noise, normalize
22
+
23
+
24
def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
    """
    Creates an OmniGenerator from the configuration, optionally initializes
    the weights of its encoder and decoders, and moves it to `device`.

    Args:
        opts (addict.Dict): configuration dict
        device (str, optional): target device. Defaults to "cpu".
        latent_shape (tuple, optional): painter latent shape, forwarded to
            OmniGenerator. Defaults to None.
        no_init (bool, optional): skip weight initialization entirely.
            Defaults to False.
        verbose (int, optional): verbosity level. Defaults to 0.

    Returns:
        OmniGenerator: the generator, on `device`
    """
    G = OmniGenerator(opts, latent_shape, verbose, no_init)
    if no_init:
        print("Sending to", device)
        return G.to(device)

    for model in G.decoders:
        net = G.decoders[model]
        # the segmentation decoder comes pretrained: do not re-initialize it
        if model == "s":
            continue
        if isinstance(net, nn.ModuleDict):
            for domain, domain_model in net.items():
                # BUGFIX: ModuleDict.items() yields (key, module); the
                # original called init_weights(net[domain_model], ...) which
                # indexed the ModuleDict with the module object itself and
                # would raise at runtime. Index with the key instead.
                init_weights(
                    net[domain],
                    init_type=opts.gen[model].init_type,
                    init_gain=opts.gen[model].init_gain,
                    verbose=verbose,
                    caller=f"create_generator decoder {model} {domain}",
                )
        else:
            init_weights(
                G.decoders[model],
                init_type=opts.gen[model].init_type,
                init_gain=opts.gen[model].init_gain,
                verbose=verbose,
                caller=f"create_generator decoder {model}",
            )
    # only the "base" encoder architecture is randomly initialized here;
    # other architectures (e.g. pretrained deeplab backbones) keep their weights
    if G.encoder is not None and opts.gen.encoder.architecture == "base":
        init_weights(
            G.encoder,
            init_type=opts.gen.encoder.init_type,
            init_gain=opts.gen.encoder.init_gain,
            verbose=verbose,
            caller="create_generator encoder",
        )

    print("Sending to", device)
    return G.to(device)
62
+
63
+
64
class OmniGenerator(nn.Module):
    def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
        """Creates the generator. All decoders listed in opts.gen will be added
        to the Generator.decoders ModuleDict if opts.gen.DecoderInitial is not True.
        Then can be accessed as G.decoders.T or G.decoders["T"] for instance,
        for the image Translation decoder

        Args:
            opts (addict.Dict): configuration dict
            latent_shape (tuple, optional): painter latent shape. Defaults to None.
            verbose (int, optional): verbosity level. Defaults to 0.
            no_init (bool, optional): skip weight init in sub-factories.
                Defaults to False.
        """
        super().__init__()
        self.opts = opts
        self.verbose = verbose
        self.encoder = None
        # a shared encoder is only needed for the masker/seg/depth tasks
        if any(t in opts.tasks for t in "msd"):
            self.encoder = create_encoder(opts, no_init, verbose)

        self.decoders = {}
        self.painter = nn.Module()

        if "d" in opts.tasks:
            self.decoders["d"] = create_depth_decoder(opts, no_init, verbose)

            if self.verbose > 0:
                print(f" - Add {self.decoders['d'].__class__.__name__}")

        if "s" in opts.tasks:
            self.decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)

        if "m" in opts.tasks:
            self.decoders["m"] = create_mask_decoder(opts, no_init, verbose)

        self.decoders = nn.ModuleDict(self.decoders)

        if "p" in self.opts.tasks:
            self.painter = create_painter(opts, no_init, verbose)
        else:
            if self.verbose > 0:
                print(" - Add Empty Painter")

    @property
    def device(self):
        """torch.device: device of the generator's (first) parameters"""
        return next(self.parameters()).device

    def __str__(self):
        return strings.generator(self)

    def encode(self, x):
        """
        Forward x through the encoder

        Args:
            x (torch.Tensor): B3HW input tensor

        Returns:
            list: High and Low level features from the encoder
        """
        assert self.encoder is not None
        return self.encoder.forward(x)

    def decode(self, x=None, z=None, return_z=False, return_z_depth=False):
        """
        Computes the predictions of all available decoders from either x or z.
        If using spade for the masker with 15 channels, x *must* be provided,
        whether z is too or not.

        Args:
            x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None.
            z (list, optional): List of high and low-level features as BCHW.
                Defaults to None.
            return_z (bool, optional): whether or not to return z in the dict.
                Defaults to False.
            return_z_depth (bool, optional): whether or not to return z_depth
                in the dict. Defaults to False.

        Raises:
            ValueError: If using spade for the masker with 15 channels but x is None

        Returns:
            dict: {task: prediction_tensor} (may include z and z_depth
                depending on args)
        """

        assert x is not None or z is not None
        # BUGFIX: the original checked `self.opts.m.spade.cond_nc` which is not
        # where this option lives (everywhere else, including make_m_cond and
        # the raised error below, reads `opts.gen.m.spade.cond_nc`); with addict
        # the wrong path silently evaluates to an empty Dict so the guard never
        # fired.
        if self.opts.gen.m.use_spade and self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )

        z_depth = cond = d = s = None
        out = {}

        if z is None:
            z = self.encode(x)

        if return_z:
            out["z"] = z

        if "d" in self.decoders:
            d, z_depth = self.decoders["d"](z)
            out["d"] = d

        if return_z_depth:
            out["z_depth"] = z_depth

        if "s" in self.decoders:
            s = self.decoders["s"](z, z_depth)
            out["s"] = s

        if "m" in self.decoders:
            if s is not None and d is not None:
                cond = self.make_m_cond(d, s, x)
            # forward the z_depth computed above (if any) so that mask() does
            # not need to run the depth decoder a second time
            m = self.mask(z=z, cond=cond, z_depth=z_depth)
            out["m"] = m

        return out

    def sample_painter_z(self, batch_size, device, force_half=False):
        """
        Samples a standard-normal latent tensor for the painter, or None
        if the painter is configured without a latent input (opts.gen.p.no_z).

        Args:
            batch_size (int): number of latent samples
            device (torch.device): device for the latent tensor
            force_half (bool, optional): cast the latent to half precision.
                Defaults to False.

        Returns:
            torch.Tensor | None: B x latent_dim x z_h x z_w latent, or None
        """
        if self.opts.gen.p.no_z:
            return None

        z = torch.empty(
            batch_size,
            self.opts.gen.p.latent_dim,
            self.painter.z_h,
            self.painter.z_w,
            device=device,
        ).normal_(mean=0, std=1.0)

        if force_half:
            z = z.half()

        return z

    def make_m_cond(self, d, s, x=None):
        """
        Create the masker's conditioning input when using spade from the
        d and s predictions and from the input x when cond_nc == 15.

        d and s are assumed to have the the same spatial resolution.
        if cond_nc == 15 then x is interpolated to match that dimension.

        Args:
            d (torch.Tensor): Raw depth prediction (B1HW)
            s (torch.Tensor): Raw segmentation prediction (BCHW)
            x (torch.Tensor, optional): Input tensor (B3hW). Mandatory
                when opts.gen.m.spade.cond_nc == 15

        Raises:
            ValueError: opts.gen.m.spade.cond_nc == 15 but x is None

        Returns:
            torch.Tensor: B x cond_nc x H x W conditioning tensor.
        """
        if self.opts.gen.m.spade.detach:
            # prevent gradients from the masker flowing into depth/seg decoders
            d = d.detach()
            s = s.detach()
        cats = [normalize(d), softmax(s, dim=1)]
        if self.opts.gen.m.spade.cond_nc == 15:
            if x is None:
                raise ValueError(
                    "When using spade for the Masker with 15 channels,"
                    + " x MUST be provided"
                )
            cats += [
                F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True)
            ]

        return torch.cat(cats, dim=1)

    def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True):
        """
        Create a mask from either an input x or a latent vector z.
        Optionally if the Masker has a spade architecture the conditioning tensor
        may be provided (cond). Default behavior applies an element-wise
        sigmoid, but can be deactivated (sigmoid=False).

        At least one of x or z must be provided (i.e. not None).
        If the Masker has a spade architecture and cond_nc == 15 then x cannot
        be None.

        Args:
            x (torch.Tensor, optional): Input tensor B3HW. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Will be computed if None. Defaults to None.
            cond (torch.Tensor, optional): spade conditioning tensor; computed
                from the depth and seg decoders if None and spade is enabled.
                Defaults to None.
            z_depth (torch.Tensor, optional): depth decoder's latent; computed
                if None and DADA is enabled. Defaults to None.
            sigmoid (bool, optional): apply a sigmoid to the mask logits.
                Defaults to True.

        Returns:
            torch.Tensor: B1HW mask tensor
        """
        assert x is not None or z is not None
        if z is None:
            z = self.encode(x)

        if cond is None and self.opts.gen.m.use_spade:
            assert "s" in self.opts.tasks and "d" in self.opts.tasks
            with torch.no_grad():
                d_pred, z_d = self.decoders["d"](z)
                s_pred = self.decoders["s"](z, z_d)
                cond = self.make_m_cond(d_pred, s_pred, x)
        if z_depth is None and self.opts.gen.m.use_dada:
            assert "d" in self.opts.tasks
            with torch.no_grad():
                _, z_depth = self.decoders["d"](z)

        if cond is not None:
            # make sure the conditioning lives on the same device as z
            device = z[0].device if isinstance(z, (tuple, list)) else z.device
            cond = cond.to(device)

        logits = self.decoders["m"](z, cond, z_depth)

        if not sigmoid:
            return logits

        return torch.sigmoid(logits)

    def paint(self, m, x, no_paste=False):
        """
        Paints given a mask and an image
        calls painter(z, x * (1.0 - m))
        Mask has 1s where water should be painted

        Args:
            m (torch.Tensor): Mask
            x (torch.Tensor): Image to paint
            no_paste (bool, optional): return the raw painter output instead
                of pasting the original content outside the mask.
                Defaults to False.

        Returns:
            torch.Tensor: painted image
        """
        z_paint = self.sample_painter_z(x.shape[0], x.device)
        m = m.to(x.dtype)
        fake = self.painter(z_paint, x * (1.0 - m))
        if self.opts.gen.p.paste_original_content and not no_paste:
            return x * (1.0 - m) + fake * m
        return fake

    def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8):
        """
        Paints x with water in m through an intermediary cloudy image
        where the sky has been replaced with perlin noise to imitate clouds.

        The intermediary cloudy image is only used to control the painter's
        painting mode, probing it with a cloudy input.

        Args:
            m (torch.Tensor): water mask
            x (torch.Tensor): input tensor
            s (torch.Tensor): segmentation prediction (BCHW)
            sky_idx (int, optional): Index of the sky class along s's C dimension.
                Defaults to 9.
            res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8).
            weight (float, optional): Intermediate image's cloud proportion
                (w * cloud + (1-w) * original_sky). Defaults to 0.8.

        Returns:
            torch.Tensor: painted image with original content pasted.
        """
        # binary sky mask from the argmax of the (upsampled) seg prediction
        sky_mask = (
            torch.argmax(
                F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True
            )
            == sky_idx
        ).to(x.dtype)
        noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype)
        # paint the noised input, then paste the *original* content back
        fake = self.paint(m, noised_x, no_paste=True)
        return x * (1.0 - m) + fake * m

    def depth(self, x=None, z=None, return_z_depth=False):
        """
        Compute the depth head's output

        Args:
            x (torch.Tensor, optional): Input B3HW tensor. Defaults to None.
            z (list, optional): High and Low level features of the encoder.
                Defaults to None.
            return_z_depth (bool, optional): also return the depth decoder's
                latent. Defaults to False.

        Returns:
            torch.Tensor: B1HW tensor of depth predictions
        """
        # exactly one of x / z must be provided
        assert x is not None or z is not None
        assert not (x is not None and z is not None)
        if z is None:
            z = self.encode(x)
        depth, z_depth = self.decoders["d"](z)

        if depth.shape[1] > 1:
            # bucketed depth: collapse the buckets and normalize to [0, 1]
            depth = torch.argmax(depth, dim=1)
            depth = depth / depth.max()

        if return_z_depth:
            return depth, z_depth

        return depth

    def load_val_painter(self):
        """
        Loads a validation painter if available in opts.val.val_painter

        Returns:
            bool: operation success status
        """
        try:
            # key exists in opts
            assert self.opts.val.val_painter

            # path exists
            ckpt_path = Path(self.opts.val.val_painter).resolve()
            assert ckpt_path.exists()

            # path is a checkpoint path
            assert ckpt_path.is_file()

            # opts are available in that path
            opts_path = ckpt_path.parent.parent / "opts.yaml"
            assert opts_path.exists()

            # load opts
            with opts_path.open("r") as f:
                val_painter_opts = Dict(yaml.safe_load(f))

            # load checkpoint
            state_dict = torch.load(ckpt_path, map_location=self.device)

            # create dummy painter from loaded opts
            painter = create_painter(val_painter_opts)

            # load state-dict in the dummy painter
            painter.load_state_dict(
                {k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
            )

            # send to current device in evaluation mode
            device = next(self.parameters()).device
            self.painter = painter.eval().to(device)

            # disable gradients
            for p in self.painter.parameters():
                p.requires_grad = False

            # success
            print(" - Loaded validation-only painter")
            return True

        except Exception as e:
            # something happened, aborting gracefully
            print(traceback.format_exc())
            print(e)
            print(">>> WARNING: error (^) in load_val_painter, aborting.")
            return False
climategan/logger.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision.utils as vutils
6
+ from addict import Dict
7
+ from PIL import Image
8
+ from torch.nn.functional import interpolate, sigmoid
9
+
10
+ from climategan.data import decode_segmap_merged_labels
11
+ from climategan.tutils import (
12
+ all_texts_to_tensors,
13
+ decode_bucketed_depth,
14
+ normalize_tensor,
15
+ write_architecture,
16
+ )
17
+ from climategan.utils import flatten_opts
18
+
19
+
20
class Logger:
    """
    Comet.ml logging helper for a trainer: uploads inference images,
    losses, learning rates and timing metrics.

    Holds a reference to the trainer and reads its state (G, opts, exp,
    display_images, schedulers) at log time.
    """

    def __init__(self, trainer):
        # per-model loss values, filled in by the trainer
        self.losses = Dict()
        # timing anchors (step_start, epoch_start), filled in by the trainer
        self.time = Dict()
        self.trainer = trainer
        self.global_step = 0
        self.epoch = 0

    def log_comet_images(self, mode, domain, minimal=False, all_only=False):
        """
        Runs inference on the trainer's display images for (mode, domain)
        and uploads per-task and combined image grids to comet.

        Args:
            mode (str): "train" or "val"
            domain (str): data domain; "rf" triggers the painter branch
            minimal (bool, optional): skip extra thresholded-mask views.
                Defaults to False.
            all_only (bool, optional): only upload the combined "all" grid,
                not the per-task grids. Defaults to False.

        Returns:
            int: 0 (dummy return value)
        """
        trainer = self.trainer
        save_images = {}
        all_images = []
        n_all_ims = None
        all_legends = ["Input"]
        task_legends = {}

        if domain not in trainer.display_images[mode]:
            return

        # --------------------
        # -----  Masker  -----
        # --------------------
        n_ims = len(trainer.display_images[mode][domain])
        print(" " * 60, end="\r")
        if domain != "rf":
            for j, display_dict in enumerate(trainer.display_images[mode][domain]):

                print(f"Inferring sample {mode} {domain} {j+1}/{n_ims}", end="\r")

                x = display_dict["data"]["x"].unsqueeze(0).to(trainer.device)
                z = trainer.G.encode(x)

                s_pred = decoded_s_pred = d_pred = z_depth = None
                # task order matters: d produces z_depth for s, and d/s
                # predictions condition m
                for k, task in enumerate(["d", "s", "m"]):

                    if (
                        task not in display_dict["data"]
                        or task not in trainer.opts.tasks
                    ):
                        continue

                    task_legend = ["Input"]
                    target = display_dict["data"][task]
                    target = target.unsqueeze(0).to(trainer.device)
                    task_saves = []

                    if task not in save_images:
                        save_images[task] = []

                    prediction = None
                    if task == "m":
                        cond = None
                        if s_pred is not None and d_pred is not None:
                            cond = trainer.G.make_m_cond(d_pred, s_pred, x)

                        prediction = trainer.G.decoders[task](z, cond, z_depth)
                    elif task == "d":
                        prediction, z_depth = trainer.G.decoders[task](z)
                    elif task == "s":
                        prediction = trainer.G.decoders[task](z, z_depth)

                    if task == "s":
                        # Log fire
                        wildfire_tens = trainer.compute_fire(x, prediction)
                        task_saves.append(wildfire_tens)
                        task_legend.append("Wildfire")
                        # Log seg output
                        s_pred = prediction.clone()
                        # convert raw class scores / labels to RGB maps
                        target = (
                            decode_segmap_merged_labels(target, domain, True)
                            .float()
                            .to(trainer.device)
                        )
                        prediction = (
                            decode_segmap_merged_labels(prediction, domain, False)
                            .float()
                            .to(trainer.device)
                        )
                        decoded_s_pred = prediction
                        task_saves.append(target)
                        task_legend.append("Target Segmentation")

                    elif task == "m":
                        # mask logits -> probabilities, tiled to 3 channels
                        prediction = sigmoid(prediction).repeat(1, 3, 1, 1)
                        task_saves.append(x * (1.0 - prediction))
                        if not minimal:
                            task_saves.append(
                                x * (1.0 - (prediction > 0.1).to(torch.int))
                            )
                            task_saves.append(
                                x * (1.0 - (prediction > 0.5).to(torch.int))
                            )

                        task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1)))
                        task_legend.append("Masked input")

                        if not minimal:
                            task_legend.append("Masked input (>0.1)")
                            task_legend.append("Masked input (>0.5)")

                        task_legend.append("Masked input (target)")
                        # dummy pixels to fool scaling and preserve mask range
                        prediction[:, :, 0, 0] = 1.0
                        prediction[:, :, -1, -1] = 0.0

                    elif task == "d":
                        # prediction is a log depth tensor
                        d_pred = prediction
                        target = normalize_tensor(target) * 255
                        if prediction.shape[1] > 1:
                            # bucketed depth: decode buckets to a single channel
                            prediction = decode_bucketed_depth(
                                prediction, self.trainer.opts
                            )
                        smogged = self.trainer.compute_smog(
                            x, d=prediction, s=decoded_s_pred, use_sky_seg=False
                        )
                        prediction = normalize_tensor(prediction)
                        prediction = prediction.repeat(1, 3, 1, 1)
                        task_saves.append(smogged)
                        task_legend.append("Smogged")
                        task_saves.append(target.repeat(1, 3, 1, 1))
                        task_legend.append("Depth target")

                    task_saves.append(prediction)
                    task_legend.append(f"Predicted {task}")

                    save_images[task].append(x.cpu().detach())
                    if k == 0:
                        # the input image is only added once to the "all" grid
                        all_images.append(save_images[task][-1])

                    task_legends[task] = task_legend
                    if j == 0:
                        all_legends += task_legend[1:]

                    for im in task_saves:
                        save_images[task].append(im.cpu().detach())
                        all_images.append(save_images[task][-1])

                if j == 0:
                    # number of images per sample row in the "all" grid
                    n_all_ims = len(all_images)

            if not all_only:
                for task in save_images.keys():
                    # Write images:
                    self.upload_images(
                        image_outputs=save_images[task],
                        mode=mode,
                        domain=domain,
                        task=task,
                        im_per_row=trainer.opts.comet.im_per_row.get(task, 4),
                        rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
                        legends=task_legends[task],
                    )

            if len(save_images) > 1:
                self.upload_images(
                    image_outputs=all_images,
                    mode=mode,
                    domain=domain,
                    task="all",
                    im_per_row=n_all_ims,
                    rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
                    legends=all_legends,
                )
        # ---------------------
        # -----  Painter  -----
        # ---------------------
        else:
            # in the rf domain display_size may be different from fid.n_images
            limit = trainer.opts.comet.display_size
            image_outputs = []
            legends = []
            for im_set in trainer.display_images[mode][domain][:limit]:
                x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
                m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)

                prediction = trainer.G.paint(m, x)

                image_outputs.append(x * (1.0 - m))
                image_outputs.append(prediction)
                image_outputs.append(x)
                image_outputs.append(prediction * m)
                # legends are identical for every sample: build them once
                if not legends:
                    legends.append("Masked Input")
                    legends.append("Painted Input")
                    legends.append("Input")
                    legends.append("Isolated Water")
            # Write images
            self.upload_images(
                image_outputs=image_outputs,
                mode=mode,
                domain=domain,
                task="painter",
                im_per_row=trainer.opts.comet.im_per_row.get("p", 4),
                rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
                legends=legends,
            )

        return 0

    def log_losses(self, model_to_update="G", mode="train"):
        """Logs metrics on comet.ml

        Args:
            model_to_update (str, optional): One of "G", "D". Defaults to "G".
            mode (str, optional): "train" or "val", used as metric prefix.
                Defaults to "train".
        """
        trainer = self.trainer
        loss_names = {"G": "gen", "D": "disc"}

        if trainer.opts.train.log_level < 1:
            return

        if trainer.exp is None:
            return

        assert model_to_update in {
            "G",
            "D",
        }, "unknown model to log losses {}".format(model_to_update)

        loss_to_update = self.losses[loss_names[model_to_update]]

        losses = loss_to_update.copy()

        if trainer.opts.train.log_level == 1:
            # Only log aggregated losses: delete other keys in losses
            for k in loss_to_update:
                if k not in {"masker", "total_loss", "painter"}:
                    del losses[k]
        # convert losses into a single-level dictionnary

        losses = flatten_opts(losses)
        trainer.exp.log_metrics(
            losses, prefix=f"{model_to_update}_{mode}", step=self.global_step
        )

    def log_learning_rates(self):
        """Logs the current generator / discriminator learning rates on comet."""
        if self.trainer.exp is None:
            return
        lrs = {}
        trainer = self.trainer
        if trainer.g_scheduler is not None:
            for name, lr in zip(
                trainer.lr_names["G"], trainer.g_scheduler.get_last_lr()
            ):
                lrs[f"lr_G_{name}"] = lr
        if trainer.d_scheduler is not None:
            for name, lr in zip(
                trainer.lr_names["D"], trainer.d_scheduler.get_last_lr()
            ):
                lrs[f"lr_D_{name}"] = lr

        trainer.exp.log_metrics(lrs, step=self.global_step)

    def log_step_time(self, time):
        """Logs step-time on comet.ml

        Args:
            time (float): current time in seconds; the logged value is
                time - self.time.step_start
        """
        if self.trainer.exp:
            self.trainer.exp.log_metric(
                "step-time", time - self.time.step_start, step=self.global_step
            )

    def log_epoch_time(self, time):
        """Logs epoch-time on comet.ml

        Args:
            time (float): current time in seconds; the logged value is
                time - self.time.epoch_start
        """
        if self.trainer.exp:
            self.trainer.exp.log_metric(
                "epoch-time", time - self.time.epoch_start, step=self.global_step
            )

    def log_comet_combined_images(self, mode, domain):
        """
        Runs the full masker + painter pipeline (soft and 0.5-binarized
        masks) on the display images for (mode, domain) and uploads a
        combined comparison grid.

        Args:
            mode (str): "train" or "val"
            domain (str): data domain

        Returns:
            int: 0 (dummy return value)
        """

        trainer = self.trainer
        image_outputs = []
        legends = []
        im_per_row = 0
        for i, im_set in enumerate(trainer.display_images[mode][domain]):
            x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
            # m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)

            m = trainer.G.mask(x=x)
            m_bin = (m > 0.5).to(m.dtype)
            prediction = trainer.G.paint(m, x)
            prediction_bin = trainer.G.paint(m_bin, x)

            image_outputs.append(x)
            legends.append("Input")
            image_outputs.append(x * (1.0 - m))
            legends.append("Soft Masked Input")
            image_outputs.append(prediction)
            legends.append("Painted")
            image_outputs.append(prediction * m)
            legends.append("Soft Masked Painted")
            image_outputs.append(x * (1.0 - m_bin))
            legends.append("Binary (0.5) Masked Input")
            image_outputs.append(prediction_bin)
            legends.append("Binary (0.5) Painted")
            image_outputs.append(prediction_bin * m_bin)
            legends.append("Binary (0.5) Masked Painted")

            if i == 0:
                im_per_row = len(image_outputs)
        # Upload images
        self.upload_images(
            image_outputs=image_outputs,
            mode=mode,
            domain=domain,
            task="combined",
            im_per_row=im_per_row or 7,
            rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
            legends=legends,
        )

        return 0

    def upload_images(
        self,
        image_outputs,
        mode,
        domain,
        task,
        im_per_row=3,
        rows_per_log=5,
        legends=[],
    ):
        """
        Save output image

        Args:
            image_outputs (list(torch.Tensor)): all the images to log
            mode (str): train or val
            domain (str): current domain
            task (str): current task
            im_per_row (int, optional): number of images to be displayed per row.
                Typically, for a given task: 3 because [input prediction, target].
                Defaults to 3.
            rows_per_log (int, optional): Number of rows (=samples) per uploaded image.
                Defaults to 5.
            legends (list(str), optional): column headers rendered above the
                grid when one legend per column is provided. Defaults to [].
        """
        trainer = self.trainer
        if trainer.exp is None:
            return
        curr_iter = self.global_step
        nb_per_log = im_per_row * rows_per_log
        n_logs = len(image_outputs) // nb_per_log + 1

        header = None
        # only render a header row when there is exactly one legend per column
        if len(legends) == im_per_row and all(isinstance(t, str) for t in legends):
            header_width = max(im.shape[-1] for im in image_outputs)
            headers = all_texts_to_tensors(legends, width=header_width)
            header = torch.cat(headers, dim=-1)

        for logidx in range(n_logs):
            print(" " * 100, end="\r", flush=True)
            print(
                "Uploading images for {} {} {} {}/{}".format(
                    mode, domain, task, logidx + 1, n_logs
                ),
                end="...",
                flush=True,
            )
            ims = image_outputs[logidx * nb_per_log : (logidx + 1) * nb_per_log]
            if not ims:
                continue

            # bring all images to a common resolution before gridding
            ims = self.upsample(ims)
            ims = torch.stack([im.squeeze() for im in ims]).squeeze()
            image_grid = vutils.make_grid(
                ims, nrow=im_per_row, normalize=True, scale_each=True, padding=0
            )

            if header is not None:
                image_grid = torch.cat(
                    [header.to(image_grid.device), image_grid], dim=1
                )

            image_grid = image_grid.permute(1, 2, 0).cpu().numpy()
            trainer.exp.log_image(
                Image.fromarray((image_grid * 255).astype(np.uint8)),
                name=f"{mode}_{domain}_{task}_{str(curr_iter)}_#{logidx}",
                step=curr_iter,
            )

    def upsample(self, ims):
        """
        Bilinearly resizes every image in `ims` to the largest (H, W)
        found in the list.

        Args:
            ims (list(torch.Tensor)): images as BCHW tensors

        Returns:
            list(torch.Tensor): resized images
        """
        h = max(im.shape[-2] for im in ims)
        w = max(im.shape[-1] for im in ims)
        new_ims = []
        for im in ims:
            im = interpolate(im, (h, w), mode="bilinear")
            new_ims.append(im)
        return new_ims

    def padd(self, ims):
        """
        Zero-pads every image in `ims` (CHW tensors) to the largest
        (H, W) found in the list, centering the original content.

        Args:
            ims (list(torch.Tensor)): images as CHW tensors

        Returns:
            list(torch.Tensor): padded images
        """
        h = max(im.shape[-2] for im in ims)
        w = max(im.shape[-1] for im in ims)
        new_ims = []
        for im in ims:
            ih = im.shape[-2]
            iw = im.shape[-1]
            if ih != h or iw != w:
                padded = torch.zeros(im.shape[-3], h, w)
                padded[
                    :, (h - ih) // 2 : (h + ih) // 2, (w - iw) // 2 : (w + iw) // 2
                ] = im
                new_ims.append(padded)
            else:
                new_ims.append(im)

        return new_ims

    def log_architecture(self):
        """Writes the model architecture to disk and uploads the files to comet."""
        write_architecture(self.trainer)

        if self.trainer.exp is None:
            return

        for f in Path(self.trainer.opts.output_path).glob("archi*.txt"):
            self.trainer.exp.log_asset(str(f), overwrite=True)
climategan/losses.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Define all losses. When possible, as inheriting from nn.Module
2
+ To send predictions to target.device
3
+ """
4
+ from random import random as rand
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import models
11
+
12
+
13
class GANLoss(nn.Module):
    def __init__(
        self,
        use_lsgan=True,
        target_real_label=1.0,
        target_fake_label=0.0,
        soft_shift=0.0,
        flip_prob=0.0,
        verbose=0,
    ):
        """Defines the GAN loss which uses either LSGAN or the regular GAN.
        When LSGAN is used, it is basically same as MSELoss,
        but it abstracts away the need to create the target label tensor
        that has the same size as the input +

        * label smoothing: target_real_label=0.75
        * label flipping: flip_prob > 0.

        source: https://github.com/sangwoomo/instagan/blob
        /b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py

        Args:
            use_lsgan (bool, optional): Use MSE or BCE. Defaults to True.
            target_real_label (float, optional): Value for the real target.
                Defaults to 1.0.
            target_fake_label (float, optional): Value for the fake target.
                Defaults to 0.0.
            soft_shift (float, optional): Upper bound of the uniform noise
                subtracted from (real) / added to (fake) the target label.
                Defaults to 0.0.
            flip_prob (float, optional): Probability of flipping the label
                (use for real target in Discriminator only). Defaults to 0.0.
            verbose (int, optional): print sampled label noise when > 0.
        """
        super().__init__()

        self.soft_shift = soft_shift
        self.verbose = verbose

        # buffers so the labels follow the module across .to(device) calls
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()
        self.flip_prob = flip_prob

    def get_target_tensor(self, input, target_is_real):
        """Build a (softened) target tensor shaped like `input`."""
        soft_change = torch.FloatTensor(1).uniform_(0, self.soft_shift)
        if self.verbose > 0:
            print("GANLoss sampled soft_change:", soft_change.item())
        if target_is_real:
            target_tensor = self.real_label - soft_change
        else:
            target_tensor = self.fake_label + soft_change
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real, *args, **kwargs):
        """Compute the loss against the (possibly flipped) target label.

        `input` may be a single prediction tensor or a list of predictions
        from a multi-scale discriminator (whose elements may themselves be
        lists of intermediate features — only the last item is used).
        """
        # Fix: decide the label flip ONCE per call. The previous version
        # toggled `target_is_real` on every iteration of the list branch,
        # alternating real/fake targets across discriminator scales.
        if rand() < self.flip_prob:
            target_is_real = not target_is_real

        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    # last element is the final prediction
                    pred_i = pred_i[-1]
                target_tensor = self.get_target_tensor(pred_i, target_is_real)
                loss_tensor = self.loss(pred_i, target_tensor.to(pred_i.device))
                loss += loss_tensor
            return loss / len(input)
        else:
            target_tensor = self.get_target_tensor(input, target_is_real)
            return self.loss(input, target_tensor.to(input.device))
84
+
85
+
86
class FeatMatchLoss(nn.Module):
    """Feature-matching loss: L1 distance between intermediate discriminator
    activations for real and fake inputs, averaged over discriminators
    (multi-scale setting). The last element of each prediction list is the
    final score and is excluded."""

    def __init__(self):
        super().__init__()
        self.criterionFeat = nn.L1Loss()

    def __call__(self, pred_real, pred_fake):
        # pred_real / pred_fake: one list of layer activations per discriminator
        num_D = len(pred_fake)
        total = 0.0
        for d_idx in range(num_D):
            # skip the last entry: it is the final prediction, not a feature
            n_feats = len(pred_fake[d_idx]) - 1
            for f_idx in range(n_feats):
                layer_loss = self.criterionFeat(
                    pred_fake[d_idx][f_idx], pred_real[d_idx][f_idx].detach()
                )
                total += layer_loss / num_D
        return total
104
+
105
+
106
class CrossEntropy(nn.Module):
    """Thin wrapper around nn.CrossEntropyLoss that moves the target to the
    logits' device and casts it to long before computing the loss."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, target):
        target = target.to(logits.device).long()
        return self.loss(logits, target)
113
+
114
+
115
class TravelLoss(nn.Module):
    """TraVeL loss: encourages the pairwise difference vectors between real
    samples and between fake samples to be aligned, via a cosine term on the
    row-normalized (and eps-clamped) difference matrices."""

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def cosine_loss(self, real, fake):
        # row-normalize, clamp each entry from below by eps, then sum the
        # diagonal of fake @ real^T (row-wise dot products) via einsum
        real_unit = real / torch.norm(real, p=2, dim=1)[:, None]
        fake_unit = fake / torch.norm(fake, p=2, dim=1)[:, None]
        real_unit = torch.max(real_unit, self.eps * torch.ones_like(real_unit))
        fake_unit = torch.max(fake_unit, self.eps * torch.ones_like(fake_unit))
        return torch.einsum("ij, ji -> i", fake_unit, real_unit).sum()

    def __call__(self, S_real, S_fake):
        real_pairs, fake_pairs = [], []
        for i in range(len(S_real)):
            for j in range(i):
                real_pairs.append((S_real[i] - S_real[j])[None, :])
                fake_pairs.append((S_fake[i] - S_fake[j])[None, :])
        # kept as attributes (as in the original) for external inspection
        self.v_real = real_pairs
        self.v_fake = fake_pairs
        self.v_real_t = torch.cat(real_pairs, dim=0)
        self.v_fake_t = torch.cat(fake_pairs, dim=0)
        return self.cosine_loss(self.v_real_t, self.v_fake_t)
140
+
141
+
142
class TVLoss(nn.Module):
    """Total Variational Regularization: penalizes squared differences between
    neighboring pixel values, normalized per element and per batch sample.

    source:
    https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
    """

    def __init__(self, tvloss_weight=1):
        """
        Args:
            tvloss_weight (int, optional): lambda, i.e. weight for the loss.
                Defaults to 1.
        """
        super().__init__()
        self.tvloss_weight = tvloss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        height = x.size()[2]
        width = x.size()[3]
        # squared differences between vertically / horizontally adjacent pixels
        vert_tv = torch.pow(x[:, :, 1:, :] - x[:, :, : height - 1, :], 2).sum()
        horiz_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, : width - 1], 2).sum()
        vert_count = self._tensor_size(x[:, :, 1:, :])
        horiz_count = self._tensor_size(x[:, :, :, 1:])
        return (
            self.tvloss_weight
            * 2
            * (vert_tv / vert_count + horiz_tv / horiz_count)
            / batch_size
        )

    def _tensor_size(self, t):
        # number of elements per batch sample
        return t.size()[1] * t.size()[2] * t.size()[3]
170
+
171
+
172
class MinentLoss(nn.Module):
    """Entropy-minimization loss on a softmax prediction map.

    Version 1 (ADVENT, https://github.com/valeoai/ADVENT): sum of the
    normalized pixel-wise entropies divided by n * h * w.
    Version 2 additionally penalizes the squared deviation of the entropy map
    from its mean, weighted by `lambda_var`.
    """

    def __init__(self, version=1, lambda_var=0.1):
        super().__init__()
        self.version = version
        self.lambda_var = lambda_var

    def __call__(self, pred):
        assert pred.dim() == 4
        n, c, h, w = pred.size()
        norm = n * h * w
        # pixel-wise entropy, scaled to [0, 1] by log2(c)
        entropy_map = -torch.mul(pred, torch.log2(pred + 1e-30)) / np.log2(c)
        if self.version == 1:
            return torch.sum(entropy_map) / norm
        demeaned = entropy_map - torch.sum(entropy_map) / norm
        return torch.sum(entropy_map + self.lambda_var * demeaned * demeaned) / norm
197
+
198
+
199
class MSELoss(nn.Module):
    """Mean squared error (squared L2 norm) between prediction and target;
    the target is first moved onto the prediction's device."""

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def __call__(self, prediction, target):
        target = target.to(prediction.device)
        return self.loss(prediction, target)
211
+
212
+
213
class L1Loss(MSELoss):
    """
    Creates a criterion that measures the mean absolute error
    (MAE) between each element in the input x and target y.

    Inherits the device-handling __call__ from MSELoss; only the underlying
    criterion is swapped for nn.L1Loss.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()
222
+
223
+
224
class SIMSELoss(nn.Module):
    """Scale-invariant MSE: mean squared residual minus the squared mean
    residual, making the loss invariant to a constant additive offset."""

    def __init__(self):
        super().__init__()

    def __call__(self, prediction, target):
        residual = prediction - target
        mse_term = torch.mean(residual * residual)
        mean_term = torch.mean(residual) * torch.mean(residual)
        return mse_term - mean_term
235
+
236
+
237
class SIGMLoss(nn.Module):
    """Scale-invariant loss with gradient matching, from the MiDaS paper.

    MiDaS did not specify how the gradients were computed; Sobel filters are
    used here to approximate image derivatives. The loss is the sum of a
    scale-and-shift-invariant absolute-residual term and a multi-scale
    gradient-matching term weighted by `gmweight`.

    Args:
        gmweight (float): weight of the gradient-matching term.
        scale (int): number of residual down-scalings (factors 1, 1/2, 1/4, ...).
        device (str): device the Sobel kernels are created on; must match the
            device of the inputs at call time (defaults to "cuda").
    """

    def __init__(self, gmweight=0.5, scale=4, device="cuda"):
        super(SIGMLoss, self).__init__()
        self.gmweight = gmweight
        # 3x3 Sobel kernels for horizontal (x) and vertical (y) derivatives
        self.sobelx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(device)
        self.sobely = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(device)
        self.scale = scale

    def __call__(self, prediction, target):
        # get disparities
        # align both the prediction and the ground truth to have zero
        # translation and unit scale (median / mean-absolute-deviation)
        t_pred = torch.median(prediction)
        t_targ = torch.median(target)
        s_pred = torch.mean(torch.abs(prediction - t_pred))
        s_targ = torch.mean(torch.abs(target - t_targ))
        pred = (prediction - t_pred) / s_pred
        targ = (target - t_targ) / s_targ

        # residual between the aligned maps
        R = pred - targ

        # get gradient map with sobel filters
        # NOTE(review): the kernels are expanded to (batch, 1, 3, 3) and used
        # as conv2d weights, which assumes single-channel (c == 1) residual
        # maps — confirm against callers.
        batch_size = prediction.size()[0]
        num_pix = prediction.size()[-1] * prediction.size()[-2]
        sobelx = (self.sobelx).expand((batch_size, 1, -1, -1))
        sobely = (self.sobely).expand((batch_size, 1, -1, -1))
        gmLoss = 0  # gradient matching term
        for k in range(self.scale):
            # residual down-scaled by 1, 1/2, 1/4, ... then filtered
            R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
            Rx = F.conv2d(R_, sobelx, stride=1)
            Ry = F.conv2d(R_, sobely, stride=1)
            gmLoss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
        gmLoss = self.gmweight / num_pix * gmLoss
        # scale invariant MSE (absolute residual, normalized per pixel)
        simseLoss = 0.5 / num_pix * torch.sum(torch.abs(R))
        loss = simseLoss + gmLoss
        return loss
279
+
280
+
281
class ContextLoss(nn.Module):
    """Masked L1 loss restricted to the NON-masked (non-water) region:
    the mask is inverted before weighting the difference."""

    def __call__(self, input, target, mask):
        weighted = (input - target) * (1 - mask)
        return torch.mean(torch.abs(weighted))
288
+
289
+
290
class ReconstructionLoss(nn.Module):
    """Masked L1 loss restricted to the masked (water) region."""

    def __call__(self, input, target, mask):
        weighted = (input - target) * mask
        return torch.mean(torch.abs(weighted))
297
+
298
+
299
+ ##################################################################################
300
+ # VGG network definition
301
+ ##################################################################################
302
+
303
+ # Source: https://github.com/NVIDIA/pix2pixHD
304
class Vgg19(nn.Module):
    """Pretrained VGG19 feature extractor for perceptual losses.

    Source: https://github.com/NVIDIA/pix2pixHD
    Splits torchvision's pretrained VGG19 `features` stack into five slices
    (layers [0:2), [2:7), [7:12), [12:21), [21:30)) and returns the
    activations after each slice.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        bounds = [0, 2, 7, 12, 21, 30]
        slices = []
        for start, stop in zip(bounds[:-1], bounds[1:]):
            block = nn.Sequential()
            for idx in range(start, stop):
                block.add_module(str(idx), vgg_pretrained_features[idx])
            slices.append(block)
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = slices
        if not requires_grad:
            # frozen feature extractor by default
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        outputs = []
        h = X
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = block(h)
            outputs.append(h)
        return outputs
335
+
336
+
337
+ # Source: https://github.com/NVIDIA/pix2pixHD
338
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 activations of x
    and y, with deeper layers weighted more heavily.

    Source: https://github.com/NVIDIA/pix2pixHD
    """

    def __init__(self, device):
        super().__init__()
        self.vgg = Vgg19().to(device).eval()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        total = 0
        # y's features are detached: only x's branch receives gradients
        for weight, fx, fy in zip(self.weights, feats_x, feats_y):
            total += weight * self.criterion(fx, fy.detach())
        return total
351
+
352
+
353
def get_losses(opts, verbose, device=None):
    """Sets the loss functions to be used by G, D and C, as specified
    in the opts and returns a dictionnary of losses:

    losses = {
        "G": {
            "gan": {"a": ..., "t": ...},
            "cycle": {"a": ..., "t": ...}
            "auto": {"a": ..., "t": ...}
            "tasks": {"h": ..., "d": ..., "s": ..., etc.}
        },
        "D": GANLoss,
        "C": ...
    }

    Args:
        opts: addict-style options; which entries are populated depends on
            opts.tasks ("p" painter, "d" depth, "s" segmentation, "m" masker).
        verbose: unused here; kept for call-site compatibility.
        device: device the VGG perceptual loss is instantiated on
            (only needed when "p" is in opts.tasks).
    """

    losses = {
        "G": {"a": {}, "p": {}, "tasks": {}},
        "D": {"default": {}, "advent": {}},
        "C": {},
    }

    # ------------------------------
    # ----- Generator Losses -----
    # ------------------------------

    # painter losses
    if "p" in opts.tasks:
        # adversarial loss: hinge or (non-LS) vanilla GAN with label softening
        losses["G"]["p"]["gan"] = (
            HingeLoss()
            if opts.gen.p.loss == "hinge"
            else GANLoss(
                use_lsgan=False,
                soft_shift=opts.dis.soft_shift,
                flip_prob=opts.dis.flip_prob,
            )
        )
        losses["G"]["p"]["dm"] = MSELoss()
        losses["G"]["p"]["vgg"] = VGGLoss(device)
        losses["G"]["p"]["tv"] = TVLoss()
        losses["G"]["p"]["context"] = ContextLoss()
        losses["G"]["p"]["reconstruction"] = ReconstructionLoss()
        losses["G"]["p"]["featmatch"] = FeatMatchLoss()

    # depth losses
    if "d" in opts.tasks:
        # regression (DADA berHu or MiDaS-style SIGM) vs. classification
        if not opts.gen.d.classify.enable:
            if opts.gen.d.loss == "dada":
                depth_func = DADADepthLoss()
            else:
                depth_func = SIGMLoss(opts.train.lambdas.G.d.gml)
        else:
            depth_func = CrossEntropy()

        losses["G"]["tasks"]["d"] = depth_func

    # segmentation losses
    if "s" in opts.tasks:
        losses["G"]["tasks"]["s"] = {}
        losses["G"]["tasks"]["s"]["crossent"] = CrossEntropy()
        losses["G"]["tasks"]["s"]["minent"] = MinentLoss()
        losses["G"]["tasks"]["s"]["advent"] = ADVENTAdversarialLoss(
            opts, gan_type=opts.dis.s.gan_type
        )

    # masker losses
    if "m" in opts.tasks:
        losses["G"]["tasks"]["m"] = {}
        losses["G"]["tasks"]["m"]["bce"] = nn.BCEWithLogitsLoss()
        # optional variance term in the entropy-minimization loss
        if opts.gen.m.use_minent_var:
            losses["G"]["tasks"]["m"]["minent"] = MinentLoss(
                version=2, lambda_var=opts.train.lambdas.advent.ent_var
            )
        else:
            losses["G"]["tasks"]["m"]["minent"] = MinentLoss()
        losses["G"]["tasks"]["m"]["tv"] = TVLoss()
        losses["G"]["tasks"]["m"]["advent"] = ADVENTAdversarialLoss(
            opts, gan_type=opts.dis.m.gan_type
        )
        losses["G"]["tasks"]["m"]["gi"] = GroundIntersectionLoss()

    # ----------------------------------
    # ----- Discriminator Losses -----
    # ----------------------------------
    # the painter's discriminator shares the generator's adversarial loss
    if "p" in opts.tasks:
        losses["D"]["p"] = losses["G"]["p"]["gan"]
    if "m" in opts.tasks or "s" in opts.tasks:
        losses["D"]["advent"] = ADVENTAdversarialLoss(opts)
    return losses
442
+
443
+
444
class GroundIntersectionLoss(nn.Module):
    """Penalizes pixels present in the pseudo ground-segmentation label but
    missing from the predicted flood mask (fraction of such pixels)."""

    def __call__(self, pred, pseudo_ground):
        missing = (pseudo_ground - pred) > 0.5
        return torch.mean(1.0 * missing)
451
+
452
+
453
def prob_2_entropy(prob):
    """Convert probabilistic prediction maps to weighted self-information maps.

    Args:
        prob (torch.Tensor): softmax probabilities, (N, C, H, W), with the
            class dimension at index 1.

    Returns:
        torch.Tensor: same shape as `prob`; element-wise -p * log2(p + 1e-30)
            normalized by log2(C) so values lie in [0, 1].
    """
    # only the class count is needed; the previous version unpacked the whole
    # size tuple and left n, h, w unused
    c = prob.size(1)
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)
459
+
460
+
461
class CustomBCELoss(nn.Module):
    """
    Binary cross-entropy against a constant label.

    The first argument is a tensor and the second argument is an int (the
    label, 0 or 1). There is no need to take sigmoid before calling this
    function (uses BCEWithLogitsLoss).
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, prediction, target):
        # Fix: torch.full_like keeps the target on prediction's device/dtype.
        # The previous tensor.get_device()-based version broke for CPU
        # tensors (get_device() returns -1 for them).
        return self.loss(prediction, torch.full_like(prediction, float(target)))
478
+
479
+
480
class ADVENTAdversarialLoss(nn.Module):
    """
    The class is for calculating the advent loss.
    It is used to indirectly shrink the domain gap between sim and real.

    __call__:
        prediction: torch.tensor with shape of [bs, c, h, w]
        target: int; domain label: 0 (sim) or 1 (real)
        discriminator: the discriminator model tells if a tensor is from
            sim or real

    output: the loss value of GANLoss
    """

    def __init__(self, opts, gan_type="GAN"):
        """
        Args:
            opts: project options (used in __call__ to detect the
                OmniDiscriminator architecture).
            gan_type (str): "GAN" for BCE-with-logits, or one of
                "WGAN" / "WGAN_gp" / "WGAN_norm" for the Wasserstein-style
                mean loss.

        Raises:
            NotImplementedError: for any other gan_type.
        """
        super().__init__()
        self.opts = opts
        if gan_type == "GAN":
            self.loss = CustomBCELoss()
        elif gan_type in ("WGAN", "WGAN_gp", "WGAN_norm"):
            # Fix: the original `gan_type == "WGAN" or "WGAN_gp" or ...`
            # was always truthy, making the NotImplementedError unreachable
            # and silently accepting any unknown gan_type.
            self.loss = lambda x, y: -torch.mean(y * x + (1 - y) * (1 - x))
        else:
            raise NotImplementedError

    def __call__(self, prediction, target, discriminator, depth_preds=None):
        """
        Compute the GAN loss from the Advent Discriminator given
        normalized (softmaxed) predictions (=pixel-wise class probabilities),
        and int labels (target).

        Args:
            prediction (torch.Tensor): pixel-wise probability distribution
                over classes
            target (torch.Tensor): pixel wise int target labels
            discriminator (torch.nn.Module): Discriminator to get the loss
            depth_preds (torch.Tensor, optional): depth map multiplied into
                the entropy map before the discriminator (DADA-style fusion)

        Returns:
            torch.Tensor: float 0-D loss
        """
        # weighted self-information map of the prediction
        d_out = prob_2_entropy(prediction)
        if depth_preds is not None:
            d_out = d_out * depth_preds
        d_out = discriminator(d_out)
        if self.opts.dis.m.architecture == "OmniDiscriminator":
            d_out = multiDiscriminatorAdapter(d_out, self.opts)
        loss_ = self.loss(d_out, target)
        return loss_
525
+
526
+
527
def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.tensor:
    """Unwrap the single-scale output of an OmniDiscriminator.

    The OmniDiscriminator returns a list of tensors rather than a tensor.
    Since there is no multilevel masker, only the first (and single) entry is
    relevant; when intermediate features are not requested, that entry is
    itself a list whose first element is the prediction tensor.

    Raises:
        Exception: if d_out is not a one-element list (multi-scale
            OmniDiscriminators are not supported).
    """
    if not (isinstance(d_out, list) and len(d_out) == 1):
        raise Exception(
            "Check the setting of OmniDiscriminator! "
            + "For now, we don't support multi-scale OmniDiscriminator."
        )
    if opts.dis.p.get_intermediate_features:
        return d_out[0]
    return d_out[0][0]
548
+
549
+
550
class HingeLoss(nn.Module):
    """
    Hinge GAN loss for the painter.
    Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py
    """

    def __init__(self, tensor=torch.FloatTensor):
        super().__init__()
        self.zero_tensor = None
        self.Tensor = tensor

    def get_zero_tensor(self, input):
        # lazily create (and cache) a zero scalar on input's device,
        # then broadcast it to input's shape
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
            self.zero_tensor = self.zero_tensor.to(input.device)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        if not for_discriminator:
            # generator branch: always pushes towards "real"
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)
        # discriminator branch: hinge on (input - 1) for real, (-input - 1) for fake
        shifted = input - 1 if target_is_real else -input - 1
        minval = torch.min(shifted, self.get_zero_tensor(input))
        return -torch.mean(minval)

    def __call__(self, input, target_is_real, for_discriminator=True):
        # |input| may be a list of predictions (multiscale discriminator),
        # whose elements may themselves be lists of intermediate features
        if not isinstance(input, list):
            return self.loss(input, target_is_real, for_discriminator)
        total = 0
        for pred_i in input:
            if isinstance(pred_i, list):
                pred_i = pred_i[-1]
            total += self.loss(pred_i, target_is_real, for_discriminator)
        return total / len(input)
594
+
595
+
596
class DADADepthLoss:
    """Reverse Huber (berHu) loss from the DADA paper for depth prediction.
    - Samples with larger residuals are penalized more by the L2 term
    - Samples with smaller residuals are penalized more by the L1 term
    From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py
    """

    def loss_calc_depth(self, pred, label):
        n, c, h, w = pred.size()
        assert c == 1

        pred = pred.squeeze()
        label = label.squeeze()

        abs_res = torch.abs(pred - label)
        # threshold: 20% of the largest absolute residual in the batch
        threshold = 0.2 * torch.max(abs_res).item()
        below = abs_res.le(threshold).float()
        above = abs_res.gt(threshold).float()
        l1_part = abs_res * below
        l2_part = ((abs_res * abs_res + threshold * threshold) / (2 * threshold)) * above
        return (torch.sum(l1_part) + torch.sum(l2_part)) / torch.numel(pred.data)

    def __call__(self, pred, label):
        return self.loss_calc_depth(pred, label)
climategan/masker.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from climategan.blocks import (
6
+ BaseDecoder,
7
+ Conv2dBlock,
8
+ InterpolateNearest2d,
9
+ SPADEResnetBlock,
10
+ )
11
+
12
+
13
def create_mask_decoder(opts, no_init=False, verbose=0):
    """Instantiate the masker's decoder according to `opts`.

    Returns a SPADE-conditioned decoder when opts.gen.m.use_spade is set
    (which requires a depth or segmentation task to condition on), and the
    plain convolutional decoder otherwise.
    """
    if not opts.gen.m.use_spade:
        if verbose > 0:
            print(" - Add Base Mask Decoder")
        return MaskBaseDecoder(opts)

    if verbose > 0:
        print(" - Add Spade Mask Decoder")
    # SPADE conditioning needs depth or segmentation predictions
    assert "d" in opts.tasks or "s" in opts.tasks
    return MaskSpadeDecoder(opts)
+
24
+
25
class MaskBaseDecoder(BaseDecoder):
    """Plain (non-SPADE) mask decoder: configures BaseDecoder's input and
    low-level-feature dimensions according to the chosen encoder backbone."""

    def __init__(self, opts):
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.m.use_low_level_feats
        use_dada = ("d" in opts.tasks) and opts.gen.m.use_dada

        # feature dimensions depend on the encoder backbone
        if use_v3 and use_mobile_net:
            input_dim, low_dim = 320, 24
        elif use_v3:
            input_dim, low_dim = 2048, 256
        else:
            # deeplabv2: low-level features are never used
            input_dim, low_dim = 2048, -1
        low_level_feats_dim = low_dim if use_low else -1

        super().__init__(
            n_upsample=opts.gen.m.n_upsample,
            n_res=opts.gen.m.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.m.proj_dim,
            output_dim=opts.gen.m.output_dim,
            norm=opts.gen.m.norm,
            activ=opts.gen.m.activ,
            pad_type=opts.gen.m.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
            use_dada=use_dada,
        )
+
58
+
59
class MaskSpadeDecoder(nn.Module):
    def __init__(self, opts):
        """Create a SPADE-based decoder, which forwards z and the conditioning
        tensors seg (in the original paper, conditioning is on a semantic map only).
        All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
        the channel dimension, and an upsampling is applied after each. Therefore
        2 upsamplings at this point. Then, for each remaining upsamplings
        (w.r.t. spade_n_up), the SRB shrinks channels by 2. Before final conv to get 3
        channels, the number of channels is therefore:
        final_nc = channels(z) * 2 ** (spade_n_up - 2)
        Args:
            latent_dim (tuple): z's shape (only the number of channels matters)
            cond_nc (int): conditioning tensor's expected number of channels
            spade_n_up (int): Number of total upsamplings from z
            spade_use_spectral_norm (bool): use spectral normalization?
            spade_param_free_norm (str): norm to use before SPADE de-normalization
            spade_kernel_size (int): SPADE conv layers' kernel size
        Returns:
            [type]: [description]
        """
        super().__init__()
        self.opts = opts
        latent_dim = opts.gen.m.spade.latent_dim
        cond_nc = opts.gen.m.spade.cond_nc
        spade_use_spectral_norm = opts.gen.m.spade.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.m.spade.spade_param_free_norm
        if self.opts.gen.m.spade.activations.all_lrelu:
            spade_activation = "lrelu"
        else:
            spade_activation = None
        spade_kernel_size = 3
        self.num_layers = opts.gen.m.spade.num_layers
        self.z_nc = latent_dim

        # The input projection depends on the encoder backbone:
        # deeplabv3 produces (high-level, low-level) feature pairs, deeplabv2
        # a single feature map.
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            # [high-level channels, low-level channels] for mobilenet
            self.input_dim = [320, 24]
            self.low_level_conv = Conv2dBlock(
                self.input_dim[1],
                self.input_dim[0],
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
            self.merge_feats_conv = Conv2dBlock(
                self.input_dim[0] * 2,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        elif (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "resnet"
        ):
            # [high-level channels, low-level channels] for resnet101
            self.input_dim = [2048, 256]
            if self.opts.gen.m.use_proj:
                # project both feature levels to proj_dim before merging
                proj_dim = self.opts.gen.m.proj_dim
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.high_level_conv = Conv2dBlock(
                    self.input_dim[0],
                    proj_dim,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    proj_dim * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
            else:
                # no projection: lift low-level features to high-level width
                self.low_level_conv = Conv2dBlock(
                    self.input_dim[1],
                    self.input_dim[0],
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )
                self.merge_feats_conv = Conv2dBlock(
                    self.input_dim[0] * 2,
                    self.z_nc,
                    3,
                    padding=1,
                    activation="lrelu",
                    pad_type="reflect",
                    norm="spectral_batch",
                )

        elif opts.gen.encoder.architecture == "deeplabv2":
            # single feature map: one conv maps it to the SPADE latent width
            self.input_dim = 2048
            self.fc_conv = Conv2dBlock(
                self.input_dim,
                self.z_nc,
                3,
                padding=1,
                activation="lrelu",
                pad_type="reflect",
                norm="spectral_batch",
            )
        else:
            raise ValueError("Unknown encoder type")

        self.spade_blocks = []

        # each SPADE resblock halves the channel count; each is followed by
        # a 2x nearest-neighbor upsampling in forward()
        for i in range(self.num_layers):
            self.spade_blocks.append(
                SPADEResnetBlock(
                    int(self.z_nc / (2**i)),
                    int(self.z_nc / (2 ** (i + 1))),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                    spade_activation,
                )
            )
        # Sequential only registers the modules; forward() indexes into it
        # because SPADE blocks take two inputs (y, cond)
        self.spade_blocks = nn.Sequential(*self.spade_blocks)

        self.final_nc = int(self.z_nc / (2**self.num_layers))
        # 1-channel output: mask logits (no activation here)
        self.mask_conv = Conv2dBlock(
            self.final_nc,
            1,
            3,
            padding=1,
            activation="none",
            pad_type="reflect",
            norm="spectral",
        )
        self.upsample = InterpolateNearest2d(scale_factor=2)

    def forward(self, z, cond, z_depth=None):
        """Decode features z, conditioned on `cond`, into mask logits.

        Args:
            z: either a (high-level, low-level) tuple (deeplabv3) or a single
                feature map (deeplabv2)
            cond: SPADE conditioning tensor (depth and/or segmentation maps)
            z_depth: unused here; kept for interface compatibility
        """
        if isinstance(z, (list, tuple)):
            z_h, z_l = z
            if self.opts.gen.m.use_proj:
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
                z_h = self.high_level_conv(z_h)
            else:
                z_l = self.low_level_conv(z_l)
                z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
            z = torch.cat([z_h, z_l], axis=1)
            y = self.merge_feats_conv(z)
        else:
            y = self.fc_conv(z)

        # alternate SPADE conditioning and 2x upsampling
        for i in range(self.num_layers):
            y = self.spade_blocks[i](y, cond)
            y = self.upsample(y)
        y = self.mask_conv(y)
        return y

    def __str__(self):
        return "MaskerSpadeDecoder"
climategan/norms.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalization layers used in blocks
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine parameters (weight / bias) are
    assigned externally before the forward pass (AdaIN).

    Implemented by reshaping (B, C, H, W) to (1, B*C, H, W) and calling
    F.batch_norm in training mode, so statistics are computed per sample and
    per channel; the registered running-stat buffers are dummies and are
    never used for normalization."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias must be assigned by the caller before forward()
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(batch)
        running_var = self.running_var.repeat(batch)

        # one "channel" per (sample, channel) pair -> instance norm via batch_norm
        flat = x.contiguous().view(1, batch * channels, *x.size()[2:])

        normed = F.batch_norm(
            flat,
            running_mean,
            running_var,
            self.weight,
            self.bias,
            True,
            self.momentum,
            self.eps,
        )

        return normed.view(batch, channels, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + "(" + str(self.num_features) + ")"
47
+
48
+
49
class LayerNorm(nn.Module):
    """Layer normalization over all non-batch dimensions, with optional
    learnable per-channel affine parameters (gamma initialized uniformly in
    [0, 1), beta at zero)."""

    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        stat_shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # single-sample fast path (faster than the batched view below
            # in old pytorch versions)
            mean = x.view(-1).mean().view(*stat_shape)
            std = x.view(-1).std().view(*stat_shape)
        else:
            flat = x.view(x.size(0), -1)
            mean = flat.mean(1).view(*stat_shape)
            std = flat.std(1).view(*stat_shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            affine_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*affine_shape) + self.beta.view(*affine_shape)
        return x
78
+
79
+
80
def l2normalize(v, eps=1e-12):
    """Scale `v` to (approximately) unit L2 norm; eps avoids division by 0."""
    norm = v.norm()
    return v / (norm + eps)
82
+
83
+
84
class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks"
    by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
    Pytorch implementation:
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps `module` and, before every forward, rescales its `name` parameter
    by a power-iteration estimate of its largest singular value.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run power iteration and set module.<name> = w_bar / sigma."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # spectral norm estimate: sigma = u^T W v
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Whether the u / v / w_bar parameters are already registered."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace module.<name> by (<name>_u, <name>_v, <name>_bar)."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # the raw weight is now derived at forward time from w_bar / sigma
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
144
+
145
+
146
class SPADE(nn.Module):
    """Spatially-Adaptive (De)normalization: normalize x with a
    parameter-free norm, then modulate it with a per-pixel scale (gamma)
    and bias (beta) predicted from a conditioning map."""

    def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc):
        super().__init__()

        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # elif param_free_norm_type == "syncbatch":
        #     self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                "%s is not a recognized param-free norm type in SPADE"
                % param_free_norm_type
            )

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        padding = kernel_size // 2

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(
            nhidden, norm_nc, kernel_size=kernel_size, padding=padding
        )
        self.mlp_beta = nn.Conv2d(
            nhidden, norm_nc, kernel_size=kernel_size, padding=padding
        )

    def forward(self, x, segmap):
        # Part 1: parameter-free normalization of the input
        normalized = self.param_free_norm(x)

        # Part 2: resize conditioning map to x's spatial size, then
        # predict the spatially-varying scale and bias
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        shared = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(shared)
        beta = self.mlp_beta(shared)

        # de-normalize with the predicted parameters
        return normalized * (1 + gamma) + beta
climategan/optim.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Define ExtraAdam and schedulers
2
+ """
3
+ import math
4
+
5
+ import torch
6
+ from torch.optim import Adam, Optimizer, RMSprop, lr_scheduler
7
+ from torch_optimizer import NovoGrad, RAdam
8
+
9
+
10
def get_scheduler(optimizer, hyperparameters, iterations=-1):
    """Get an optimizer's learning rate scheduler based on opts

    Args:
        optimizer (torch.Optimizer): optimizer for which to schedule the learning rate
        hyperparameters (addict.Dict): configuration options; relevant keys are
            "lr_policy" (str), "lr_step_size" (int), "lr_gamma" (float) and
            "lr_milestones" (int or list of ints)
        iterations (int, optional): The index of last epoch. Defaults to -1.
            When last_epoch=-1, sets initial lr as lr.

    Raises:
        NotImplementedError: if lr_policy is not one of
            {None, "constant", "step", "multi_step"}

    Returns:
        torch.optim.lr_scheduler._LRScheduler | None: None means constant lr
    """
    policy = hyperparameters.get("lr_policy")
    lr_step_size = hyperparameters.get("lr_step_size")
    lr_gamma = hyperparameters.get("lr_gamma")
    milestones = hyperparameters.get("lr_milestones")

    if policy is None or policy == "constant":
        scheduler = None  # constant scheduler
    elif policy == "step":
        scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=lr_step_size,
            gamma=lr_gamma,
            last_epoch=iterations,
        )
    elif policy == "multi_step":
        if isinstance(milestones, int):
            # an int means: a milestone every lr_step_size steps starting at
            # `milestones`, up to `iterations` (or 1000 when unknown)
            assert "lr_step_size" in hyperparameters
            last_milestone = 1000 if iterations == -1 else iterations
            milestones = list(range(milestones, last_milestone, lr_step_size))
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=lr_gamma,
            last_epoch=iterations,
        )
    else:
        # BUGFIX: the original `return NotImplementedError(...)` handed the
        # exception instance back to the caller instead of raising it, so an
        # unknown policy silently produced a bogus "scheduler".
        raise NotImplementedError(
            "learning rate policy [%s] is not implemented" % policy
        )
    return scheduler
52
+
53
+
54
def get_optimizer(net, opt_conf, tasks=None, is_disc=False, iterations=-1):
    """Returns a tuple (optimizer, scheduler, lr_names) according to opt_conf which
    should come from the trainer's opts as: trainer.opts.<model>.opt

    Args:
        net (nn.Module): Network to update
        opt_conf (addict.Dict): optimizer and scheduler options; `opt_conf.lr`
            is either a float (one lr for everything) or a dict with a
            "default" key and optional per-task overrides
        tasks: list of tasks
        is_disc (bool): presumably `net` is then a dict of per-task
            discriminators rather than a generator -- verify against callers
        iterations (int, optional): Last epoch number. Defaults to -1, meaning
            start with base lr.

    Returns:
        Tuple: (torch.Optimizer, torch._LRScheduler or None, list of str)
    """
    opt = scheduler = None
    lr_names = []
    if tasks is None:
        lr_default = opt_conf.lr
        params = net.parameters()
        lr_names.append("full")
    elif isinstance(opt_conf.lr, float):  # Use default for all tasks
        lr_default = opt_conf.lr
        params = net.parameters()
        lr_names.append("full")
    elif len(opt_conf.lr) == 1:  # Use default for all tasks
        lr_default = opt_conf.lr.default
        params = net.parameters()
        lr_names.append("full")
    else:
        # one parameter group per task, each with its own learning rate
        lr_default = opt_conf.lr.default
        params = list()
        for task in tasks:
            lr = opt_conf.lr.get(task, lr_default)
            parameters = None
            # Parameters for encoder
            if not is_disc:
                if task == "m":
                    # the masker task also trains the shared encoder; its
                    # param group is appended here, the decoder's below
                    parameters = net.encoder.parameters()
                    params.append({"params": parameters, "lr": lr})
                    lr_names.append("encoder")
                # Parameters for decoders
                if task == "p":
                    if hasattr(net, "painter"):
                        parameters = net.painter.parameters()
                        lr_names.append("painter")
                else:
                    parameters = net.decoders[task].parameters()
                    lr_names.append(f"decoder_{task}")
            else:
                if task in net:
                    parameters = net[task].parameters()
                    lr_names.append(f"disc_{task}")

            if parameters is not None:
                params.append({"params": parameters, "lr": lr})

    if opt_conf.optimizer.lower() == "extraadam":
        opt = ExtraAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
    elif opt_conf.optimizer.lower() == "novograd":
        opt = NovoGrad(
            params, lr=lr_default, betas=(opt_conf.beta1, 0)
        )  # default for beta2 is 0
    elif opt_conf.optimizer.lower() == "radam":
        opt = RAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
    elif opt_conf.optimizer.lower() == "rmsprop":
        opt = RMSprop(params, lr=lr_default)
    else:
        # fall back to plain Adam for any unrecognized optimizer name
        opt = Adam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
    scheduler = get_scheduler(opt, opt_conf, iterations)
    return opt, scheduler, lr_names
124
+
125
+
126
+ """
127
+ Extragradient Optimizer
128
+
129
+ Mostly copied from the extragrad paper repo.
130
+
131
+ MIT License
132
+ Copyright (c) Facebook, Inc. and its affiliates.
133
+ written by Hugo Berard (berard.hugo@gmail.com) while at Facebook.
134
+ """
135
+
136
+
137
class Extragradient(Optimizer):
    """Base class for optimizers with extrapolation step.
    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        super().__init__(params, defaults)
        self.params_copy = []

    def update(self, p, group):
        """Compute the update direction for parameter `p` (subclass hook)."""
        raise NotImplementedError

    def extrapolation(self):
        """Performs the extrapolation step and save a copy of the current
        parameters for the update step.
        """
        # Only snapshot parameters on the first extrapolation since the last
        # update: several extrapolation steps may precede a single step().
        needs_snapshot = len(self.params_copy) == 0
        for group in self.param_groups:
            for param in group["params"]:
                direction = self.update(param, group)
                if needs_snapshot:
                    # Save the current parameters for the update step.
                    self.params_copy.append(param.data.clone())
                if direction is None:
                    continue
                # Move to the extrapolated point
                param.data.add_(direction)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        if len(self.params_copy) == 0:
            raise RuntimeError("Need to call extrapolation before calling step.")

        loss = closure() if closure is not None else None

        idx = -1
        for group in self.param_groups:
            for param in group["params"]:
                idx += 1
                direction = self.update(param, group)
                if direction is None:
                    continue
                # Apply the update to the parameters saved at extrapolation
                # time, not to the extrapolated point
                param.data = self.params_copy[idx].add_(direction)

        # Free the old parameters
        self.params_copy = []
        return loss
198
+
199
+
200
class ExtraAdam(Extragradient):
    """Implements the Adam algorithm with extrapolation step.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super(ExtraAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(ExtraAdam, self).__setstate__(state)
        for group in self.param_groups:
            # older checkpoints may predate the amsgrad option
            group.setdefault("amsgrad", False)

    def update(self, p, group):
        """Return the Adam update direction for `p`, or None if it has no
        gradient.

        Called by Extragradient for both the extrapolation and the update
        steps; optimizer state (moments, step count) advances on every call.
        """
        if p.grad is None:
            return None
        grad = p.grad.data
        if grad.is_sparse:
            raise RuntimeError(
                "Adam does not support sparse gradients,"
                + " please consider SparseAdam instead"
            )
        amsgrad = group["amsgrad"]

        state = self.state[p]

        # State initialization
        if len(state) == 0:
            state["step"] = 0
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(p.data)
            if amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state["max_exp_avg_sq"] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        if amsgrad:
            max_exp_avg_sq = state["max_exp_avg_sq"]
        beta1, beta2 = group["betas"]

        state["step"] += 1

        if group["weight_decay"] != 0:
            # modern keyword form; the positional (Number, Tensor) overload
            # used originally is deprecated in recent PyTorch
            grad = grad.add(p.data, alpha=group["weight_decay"])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)  # type: ignore
            # Use the max. for normalizing running avg. of gradient
            denom = max_exp_avg_sq.sqrt().add_(group["eps"])  # type: ignore
        else:
            denom = exp_avg_sq.sqrt().add_(group["eps"])

        bias_correction1 = 1 - beta1 ** state["step"]
        bias_correction2 = 1 - beta2 ** state["step"]
        step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

        return -step_size * exp_avg / denom
climategan/painter.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import climategan.strings as strings
6
+ from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
7
+ from climategan.norms import SpectralNorm
8
+
9
+
10
def create_painter(opts, no_init=False, verbose=0):
    """Build the painter network (a SPADE-based decoder).

    `no_init` is accepted for interface parity with other factories but is
    not used here.
    """
    if verbose > 0:
        print(" - Add PainterSpadeDecoder Painter")
    return PainterSpadeDecoder(opts)
14
+
15
+
16
class PainterSpadeDecoder(nn.Module):
    """SPADE-conditioned decoder used as the Painter: decodes a latent map
    into a 3-channel image, conditioned at every scale on `cond` (the masked
    input image)."""

    def __init__(self, opts):
        """Create a SPADE-based decoder, which forwards z and the conditioning
        tensors seg (in the original paper, conditioning is on a semantic map only).
        All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
        the channel dimension, and an upsampling is applied after each. Therefore
        2 upsamplings at this point. Then, for each remaining upsamplings
        (w.r.t. spade_n_up), the SRB shrinks channels by 2. Before final conv to get 3
        channels, the number of channels is therefore:
            final_nc = channels(z) / 2 ** (spade_n_up - 2)
        Hyper-parameters (read from opts.gen.p):
            latent_dim (int): z's number of channels
            cond_nc (int): conditioning tensor's expected number of channels
            spade_n_up (int): Number of total upsamplings from z
            spade_use_spectral_norm (bool): use spectral normalization?
            spade_param_free_norm (str): norm to use before SPADE de-normalization
            spade_kernel_size (int): SPADE conv layers' kernel size
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        cond_nc = 3  # conditioning is an RGB image
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up

        # latent spatial size; set by set_latent_shape() before forward(z=None)
        self.z_h = self.z_w = None

        # projects the (downsampled) conditioning image into the latent space
        self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)
        self.head_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.G_middle_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.G_middle_1 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        # channel-halving SPADE blocks, one per remaining upsampling
        self.up_spades = nn.Sequential(
            *[
                SPADEResnetBlock(
                    self.z_nc // 2 ** i,
                    self.z_nc // 2 ** (i + 1),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                )
                for i in range(spade_n_up - 2)
            ]
        )

        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)

        self.final_spade = SPADEResnetBlock(
            self.final_nc,
            self.final_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        # optional: derive the final SPADE block's 3-channel conditioning
        # from the last feature map instead of the input image
        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            self.final_shortcut = nn.Sequential(
                *[
                    SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
                    nn.BatchNorm2d(3),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)

        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
        If is_input is True, then this is the actual input shape which should
        be divided by 2 ** spade_n_up
        Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]

        Args:
            shape (tuple | int): The shape to start sampling from.
            is_input (bool, optional): Whether to divide shape by 2 ** spade_n_up

        Raises:
            ValueError: if shape is neither a tuple/list nor an int
        """
        if isinstance(shape, (list, tuple)):
            self.z_h = shape[-2]
            self.z_w = shape[-1]
        elif isinstance(shape, int):
            self.z_h = self.z_w = shape
        else:
            raise ValueError("Unknown shape type:", shape)

        if is_input:
            self.z_h = self.z_h // (2 ** self.spade_n_up)
            self.z_w = self.z_w // (2 ** self.spade_n_up)

    def _apply(self, fn):
        # print("Applying SpadeDecoder", fn)
        super()._apply(fn)
        # self.head_0 = fn(self.head_0)
        # self.G_middle_0 = fn(self.G_middle_0)
        # self.G_middle_1 = fn(self.G_middle_1)
        # for i, up in enumerate(self.up_spades):
        #     self.up_spades[i] = fn(up)
        # self.conv_img = fn(self.conv_img)
        return self

    def forward(self, z, cond):
        """Decode latent `z` into an image conditioned on `cond`. If `z` is
        None, it is created by projecting `cond` resized to (z_h, z_w)."""
        if z is None:
            assert self.z_h is not None and self.z_w is not None
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
        y = self.head_0(z, cond)
        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        for i, up in enumerate(self.up_spades):
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # NOTE: `cond` is replaced by features derived from y for the
            # final SPADE block (the "final shortcut")
            cond = self.final_shortcut(y)
        y = self.final_spade(y, cond)
        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        return strings.spadedecoder(self)
climategan/strings.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """custom __str__ methods for ClimateGAN's classes
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
def title(name, color="\033[94m"):
    """Return `name` framed by '=' rules, bold and ANSI-colored."""
    framed = "==== " + name + " ===="
    bar = "=" * len(framed)
    boxed = f"{bar}\n{framed}\n{bar}"
    return f"\033[1m{color}{boxed}\033[0m"
12
+
13
+
14
def generator(G):
    """String representation of an OmniGenerator: encoder then each decoder,
    with special-cased rendering for the "a" and "t" decoder groups."""
    s = title("OmniGenerator", "\033[95m") + "\n"

    s += str(G.encoder) + "\n\n"
    for d in G.decoders:
        if d == "a":
            s += "[r & s]\n" + str(G.decoders["a"]["r"]) + "\n\n"
        elif d == "t":
            if G.opts.gen.t.use_bit_conditioning:
                s += "[bit]\n" + str(G.decoders["t"]) + "\n\n"
            else:
                s += "[f & n]\n" + str(G.decoders["t"]["f"]) + "\n\n"
        else:
            s += str(G.decoders[d]) + "\n\n"
    return s.strip()
29
+
30
+
31
def encoder(E):
    """String representation of an encoder: title plus one line per block."""
    lines = [title("Encoder")] + [str(block) for block in E.model]
    return "\n".join(lines).strip()
36
+
37
+
38
def get_conv_weight(conv):
    """Return the shape of `conv`'s weight tensor without materializing it.

    Computed from the layer's attributes (so it also works when the raw
    `weight` parameter was moved, e.g. renamed to `weight_bar` by
    SpectralNorm). The original implementation allocated an uninitialized
    torch.Tensor just to read its `.shape`; this returns the same
    `torch.Size` with no allocation.

    Args:
        conv (nn.Conv2d): convolution layer to describe

    Returns:
        torch.Size: (out_channels, in_channels // groups, *kernel_size)
    """
    return torch.Size(
        (conv.out_channels, conv.in_channels // conv.groups, *conv.kernel_size)
    )
43
+
44
+
45
def conv2dblock(obj):
    """One-line description of a Conv2dBlock, flagging spectral norm."""
    label = "{:20}".format("Conv2dBlock")
    if "SpectralNorm" in obj.conv.__class__.__name__:
        prefix = "SpectralNorm => "
        shape = str(tuple(get_conv_weight(obj.conv.module)))
    else:
        prefix = ""
        shape = str(tuple(get_conv_weight(obj.conv)))
    return f"{label}{prefix}{shape}".strip()
54
+
55
+
56
def resblocks(rb):
    """Multi-line description of a ResBlocks container, one line per block."""
    header = f"ResBlocks({len(rb.model)})"
    body = "".join(f" - ({i}) {block}\n" for i, block in enumerate(rb.model))
    return f"{header}\n{body}".strip()
61
+
62
+
63
def resblock(rb):
    """One-line description of a single Resblock."""
    label = "{:12}".format("Resblock")
    return f"{label}{rb.dim} channels, {rb.norm} norm + {rb.activation}"
66
+
67
+
68
def basedecoder(bd):
    """Multi-line description of a base decoder: title then one line per
    layer, collapsing upsampling layers to a compact 'Upsample x2' tag."""
    s = title(bd.__class__.__name__) + "\n"
    for layer in bd.model:
        is_up = isinstance(layer, nn.Upsample)
        if is_up or "InterpolateNearest2d" in layer.__class__.__name__:
            s += "{:20}".format("Upsample") + "x2\n"
        else:
            s += str(layer) + "\n"
    return s.strip()
76
+
77
+
78
def spaderesblock(srb):
    """One-line description of a SPADEResnetBlock."""
    head = "{:20}".format("SPADEResnetBlock") + f"k {srb.kernel_size}, "
    desc = (
        f"{head}{srb.fin} > {srb.fout}, "
        f"param_free_norm: {srb.param_free_norm}, "
        f"spectral_norm: {srb.use_spectral_norm}"
    )
    return desc.strip()
84
+
85
+
86
def spadedecoder(sd):
    """Multi-line description of a SPADE decoder: each block preceded by an
    'Upsample x2' tag, ending with the final conv + tanh."""
    upsample_line = "{:20}x2\n".format("Upsample")
    s = title(sd.__class__.__name__) + "\n"
    for block in (sd.head_0, sd.G_middle_0, sd.G_middle_1):
        s += upsample_line + str(block) + "\n"
    for block in sd.up_spades:
        s += upsample_line + str(block) + "\n"
    s += "{:20}".format("Conv2d") + str(tuple(get_conv_weight(sd.conv_img))) + " tanh"
    return s
climategan/trainer.py ADDED
@@ -0,0 +1,1939 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main component: the trainer handles everything:
3
+ * initializations
4
+ * training
5
+ * saving
6
+ """
7
+ import inspect
8
+ import warnings
9
+ from copy import deepcopy
10
+ from pathlib import Path
11
+ from time import time
12
+
13
+ import numpy as np
14
+ from comet_ml import ExistingExperiment, Experiment
15
+
16
+ warnings.simplefilter("ignore", UserWarning)
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from addict import Dict
21
+ from torch import autograd, sigmoid, softmax
22
+ from torch.cuda.amp import GradScaler, autocast
23
+ from tqdm import tqdm
24
+
25
+ from climategan.data import get_all_loaders
26
+ from climategan.discriminator import OmniDiscriminator, create_discriminator
27
+ from climategan.eval_metrics import accuracy, mIOU
28
+ from climategan.fid import compute_val_fid
29
+ from climategan.fire import add_fire
30
+ from climategan.generator import OmniGenerator, create_generator
31
+ from climategan.logger import Logger
32
+ from climategan.losses import get_losses
33
+ from climategan.optim import get_optimizer
34
+ from climategan.transforms import DiffTransforms
35
+ from climategan.tutils import (
36
+ divide_pred,
37
+ get_num_params,
38
+ get_WGAN_gradient,
39
+ lrgb2srgb,
40
+ normalize,
41
+ print_num_parameters,
42
+ shuffle_batch_tuple,
43
+ srgb2lrgb,
44
+ vgg_preprocess,
45
+ zero_grad,
46
+ )
47
+ from climategan.utils import (
48
+ comet_kwargs,
49
+ div_dict,
50
+ find_target_size,
51
+ flatten_opts,
52
+ get_display_indices,
53
+ get_existing_comet_id,
54
+ get_latest_opts,
55
+ merge,
56
+ resolve,
57
+ sum_dict,
58
+ Timer,
59
+ )
60
+
61
+ try:
62
+ import torch_xla.core.xla_model as xm # type: ignore
63
+ except ImportError:
64
+ pass
65
+
66
+
67
+ class Trainer:
68
+ """Main trainer class"""
69
+
70
    def __init__(self, opts, comet_exp=None, verbose=0, device=None):
        """Trainer class to gather various model training procedures
        such as training evaluating saving and logging

        init:
        * creates an addict.Dict logger
        * stores `comet_exp` as self.exp if it is a comet_ml Experiment
        * sets the device (1 GPU or CPU)

        Args:
            opts (addict.Dict): options to configure the trainer, the data, the models
            comet_exp (comet_ml.Experiment, optional): existing experiment to
                log to. Defaults to None (no comet logging).
            verbose (int, optional): printing level to debug. Defaults to 0.
            device (torch.device, optional): force a device; defaults to
                cuda:0 when available, else cpu.

        Raises:
            ValueError: if AMP is enabled together with the ExtraAdam
                optimizer (incompatible with GradScaler)
        """
        super().__init__()

        self.opts = opts
        self.verbose = verbose
        self.logger = Logger(self)

        # populated later by the setup/training methods
        self.losses = None
        self.G = self.D = None
        self.real_val_fid_stats = None
        self.use_pl4m = False
        self.is_setup = False
        self.loaders = self.all_loaders = None
        self.exp = None

        self.current_mode = "train"
        self.diff_transforms = None
        self.kitti_pretrain = self.opts.train.kitti.pretrain
        self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)

        self.lr_names = {}
        self.base_display_images = {}
        self.kitty_display_images = {}
        # domain labels: synthetic -> 0, real -> 1
        self.domain_labels = {"s": 0, "r": 1}

        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        if isinstance(comet_exp, Experiment):
            self.exp = comet_exp

        if self.opts.train.amp:
            optimizers = [
                self.opts.gen.opt.optimizer.lower(),
                self.opts.dis.opt.optimizer.lower(),
            ]
            # ExtraAdam's extrapolation/step pair is incompatible with
            # torch.cuda.amp.GradScaler
            if "extraadam" in optimizers:
                raise ValueError(
                    "AMP does not work with ExtraAdam ({})".format(optimizers)
                )
            self.grad_scaler_d = GradScaler()
            self.grad_scaler_g = GradScaler()

        # -------------------------------
        # ----- Legacy Overwrites -----
        # -------------------------------
        # older configs used *_fusion flags; both imply use_dada
        if (
            self.opts.gen.s.depth_feat_fusion is True
            or self.opts.gen.s.depth_dada_fusion is True
        ):
            self.opts.gen.s.use_dada = True
136
+
137
    @torch.no_grad()
    def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
        """
        Paints a batch of images (or a single image with a batch dim of 1). If
        masks are not provided, they are inferred from the masker.
        Resolution can either be the train-time resolution or the closest
        multiple of 2 ** spade_n_up

        Operations performed without gradient

        If resolution == "approx" then the output image has the shape:
            (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
            eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
        If resolution == "exact" then the output image has the same shape:
            we first process in "approx" mode then upsample bilinear
        If resolution == "basic" image output shape is the train-time's
            (typically 640x640)
        If resolution == "upsample" image is inferred as "basic" and
            then upsampled to original size

        Args:
            image_batch (torch.Tensor): 4D batch of images to flood
            mask_batch (torch.Tensor, optional): Masks for the images.
                Defaults to None (infer with Masker).
            resolution (str, optional): "approx", "exact", "basic" or "upsample"

        Returns:
            torch.Tensor: N x C x H x W where H and W depend on `resolution`
        """
        assert resolution in {"approx", "exact", "basic", "upsample"}
        # remember the trainer's mode so it can be restored on exit
        previous_mode = self.current_mode
        if previous_mode == "train":
            self.eval_mode()

        if mask_batch is None:
            # no masks given: infer them from the masker
            mask_batch = self.G.mask(x=image_batch)
        else:
            # masks must match the images in batch size and spatial dims
            assert len(image_batch) == len(mask_batch)
            assert image_batch.shape[-2:] == mask_batch.shape[-2:]

        if resolution not in {"approx", "exact"}:
            # "basic" / "upsample": paint at the painter's current latent size
            painted = self.G.paint(mask_batch, image_batch)

            if resolution == "upsample":
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )
        else:
            # save latent shape
            zh = self.G.painter.z_h
            zw = self.G.painter.z_w
            # adapt latent shape to approximately keep the resolution
            self.G.painter.z_h = (
                image_batch.shape[-2] // 2**self.opts.gen.p.spade_n_up
            )
            self.G.painter.z_w = (
                image_batch.shape[-1] // 2**self.opts.gen.p.spade_n_up
            )

            painted = self.G.paint(mask_batch, image_batch)

            # restore the painter's original latent shape
            self.G.painter.z_h = zh
            self.G.painter.z_w = zw
            if resolution == "exact":
                # recover the exact input resolution with a bilinear upsample
                painted = nn.functional.interpolate(
                    painted, size=image_batch.shape[-2:], mode="bilinear"
                )

        if previous_mode == "train":
            self.train_mode()

        return painted
209
+
210
+ def _p(self, *args, **kwargs):
211
+ """
212
+ verbose-dependant print util
213
+ """
214
+ if self.verbose > 0:
215
+ print(*args, **kwargs)
216
+
217
+ @torch.no_grad()
218
+ def infer_all(
219
+ self,
220
+ x,
221
+ numpy=True,
222
+ stores={},
223
+ bin_value=-1,
224
+ half=False,
225
+ xla=False,
226
+ cloudy=False,
227
+ auto_resize_640=False,
228
+ ignore_event=set(),
229
+ return_masks=False,
230
+ ):
231
+ """
232
+ Create a dictionnary of events from a numpy or tensor,
233
+ single or batch image data.
234
+
235
+ stores is a dictionnary of times for the Timer class.
236
+
237
+ bin_value is used to binarize (or not) flood masks
238
+ """
239
+ assert self.is_setup
240
+ assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
241
+
242
+ # convert numpy to tensor
243
+ if not isinstance(x, torch.Tensor):
244
+ x = torch.tensor(x, device=self.device)
245
+
246
+ # add batch dimension
247
+ if len(x.shape) == 3:
248
+ x.unsqueeze_(0)
249
+
250
+ # permute channels as second dimension
251
+ if x.shape[1] != 3:
252
+ assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
253
+ x = x.permute(0, 3, 1, 2)
254
+
255
+ # send to device
256
+ if x.device != self.device:
257
+ x = x.to(self.device)
258
+
259
+ # interpolate to standard input size
260
+ if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
261
+ x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")
262
+
263
+ if half:
264
+ x = x.half()
265
+
266
+ # adjust painter's latent vector
267
+ self.G.painter.set_latent_shape(x.shape, True)
268
+
269
+ with Timer(store=stores.get("all events", [])):
270
+ # encode
271
+ with Timer(store=stores.get("encode", [])):
272
+ z = self.G.encode(x)
273
+ if xla:
274
+ xm.mark_step()
275
+
276
+ # predict from masker
277
+ with Timer(store=stores.get("depth", [])):
278
+ depth, z_depth = self.G.decoders["d"](z)
279
+ if xla:
280
+ xm.mark_step()
281
+ with Timer(store=stores.get("segmentation", [])):
282
+ segmentation = self.G.decoders["s"](z, z_depth)
283
+ if xla:
284
+ xm.mark_step()
285
+ with Timer(store=stores.get("mask", [])):
286
+ cond = self.G.make_m_cond(depth, segmentation, x)
287
+ mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
288
+ if xla:
289
+ xm.mark_step()
290
+
291
+ # apply events
292
+ if "wildfire" not in ignore_event:
293
+ with Timer(store=stores.get("wildfire", [])):
294
+ wildfire = self.compute_fire(x, seg_preds=segmentation)
295
+ if "smog" not in ignore_event:
296
+ with Timer(store=stores.get("smog", [])):
297
+ smog = self.compute_smog(x, d=depth, s=segmentation)
298
+ if "flood" not in ignore_event:
299
+ with Timer(store=stores.get("flood", [])):
300
+ flood = self.compute_flood(
301
+ x,
302
+ m=mask,
303
+ s=segmentation,
304
+ cloudy=cloudy,
305
+ bin_value=bin_value,
306
+ )
307
+
308
+ if xla:
309
+ xm.mark_step()
310
+
311
+ if numpy:
312
+ with Timer(store=stores.get("numpy", [])):
313
+ # normalize to 0-1
314
+ flood = normalize(flood).cpu()
315
+ smog = normalize(smog).cpu()
316
+ wildfire = normalize(wildfire).cpu()
317
+
318
+ # convert to numpy
319
+ flood = flood.permute(0, 2, 3, 1).numpy()
320
+ smog = smog.permute(0, 2, 3, 1).numpy()
321
+ wildfire = wildfire.permute(0, 2, 3, 1).numpy()
322
+
323
+ # convert to 0-255 uint8
324
+ flood = (flood * 255).astype(np.uint8)
325
+ smog = (smog * 255).astype(np.uint8)
326
+ wildfire = (wildfire * 255).astype(np.uint8)
327
+
328
+ output_data = {"flood": flood, "wildfire": wildfire, "smog": smog}
329
+ if return_masks:
330
+ output_data["mask"] = (
331
+ ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
332
+ )
333
+
334
+ return output_data
335
+
336
    @classmethod
    def resume_from_path(
        cls,
        path,
        overrides={},
        setup=True,
        inference=False,
        new_exp=False,
        device=None,
        verbose=1,
    ):
        """
        Resume and optionally setup a trainer from a specific path,
        using the latest opts and checkpoint. Requires path to contain opts.yaml
        (or increased), url.txt (or increased) and checkpoints/

        Args:
            path (str | pathlib.Path): Trainer to resume
            overrides (dict, optional): Override loaded opts with those. Defaults to {}.
            setup (bool, optional): Whether or not to setup the trainer before
                returning it. Defaults to True.
            inference (bool, optional): Setup should be done in inference mode or not.
                Defaults to False.
            new_exp (bool | None, optional): None disables comet logging; True
                creates a fresh experiment; False (default) resumes the
                existing comet experiment found in `path`.
            device (torch.device, optional): Device to use
            verbose (int, optional): Trainer verbosity level. Defaults to 1.

        Returns:
            climategan.Trainer: Loaded and resumed trainer
        """
        # NOTE(review): `overrides={}` is a mutable default; safe only if
        # `merge` does not mutate its first argument — confirm.
        p = resolve(path)
        assert p.exists()

        c = p / "checkpoints"
        assert c.exists() and c.is_dir()

        # load the run's latest opts, apply caller overrides, force resume
        opts = get_latest_opts(p)
        opts = Dict(merge(overrides, opts))
        opts.train.resume = True

        if new_exp is None:
            # no comet logging at all
            exp = None
        elif new_exp is True:
            # brand new comet experiment: log code + hyper-parameters
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(resolve(Path(__file__)).parent),
                recursive=True,
                log_file_name=True,
            )
            exp.log_parameters(flatten_opts(opts))
        else:
            # resume the comet experiment whose id is stored in `p`
            comet_id = get_existing_comet_id(p)
            exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)

        trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)

        if setup:
            trainer.setup(inference=inference)
        return trainer
395
+
396
    def save(self):
        """Save the trainer's state to
        `opts.output_path/checkpoints/latest_ckpt.pth`: current epoch and
        global step, G and its optimizer, and D and its optimizer when D has
        parameters. Additionally writes an epoch-stamped snapshot
        (`epoch_{E}_ckpt.pth`) every `opts.train.save_n_epochs` epochs once
        past `opts.train.min_save_epoch`.
        """
        save_dir = Path(self.opts.output_path) / Path("checkpoints")
        save_dir.mkdir(exist_ok=True)
        save_path = save_dir / "latest_ckpt.pth"

        # Construct relevant state dicts / optims:
        # Save at least G
        save_dict = {
            "epoch": self.logger.epoch,
            "G": self.G.state_dict(),
            "g_opt": self.g_opt.state_dict(),
            "step": self.logger.global_step,
        }

        # D is optional: only persisted when it actually has parameters
        if self.D is not None and get_num_params(self.D) > 0:
            save_dict["D"] = self.D.state_dict()
            save_dict["d_opt"] = self.d_opt.state_dict()

        # periodic epoch-stamped snapshot
        if (
            self.logger.epoch >= self.opts.train.min_save_epoch
            and self.logger.epoch % self.opts.train.save_n_epochs == 0
        ):
            torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")

        # always refresh the "latest" checkpoint
        torch.save(save_dict, save_path)
421
+
422
    def resume(self, inference=False):
        """Restore the trainer's state from checkpoint(s) resolved from
        `opts.load_paths` (m / p / pm) or, by default, from
        `opts.output_path/checkpoints/latest_ckpt.pth`.

        In inference mode only G's weights are loaded (non-strictly); in
        training mode the optimizers, scheduler position, D and the logger's
        epoch/step are restored as well.

        Args:
            inference (bool, optional): only load G, skip optimizers / D /
                logger restoration. Defaults to False.
        """
        tpu = "xla" in str(self.device)
        if tpu:
            print("Resuming on TPU:", self.device)

        m_path = Path(self.opts.load_paths.m)
        p_path = Path(self.opts.load_paths.p)
        pm_path = Path(self.opts.load_paths.pm)
        output_path = Path(self.opts.output_path)

        # on TPU, map to CPU first; tensors are moved to the xla device below
        map_loc = self.device if not tpu else "cpu"

        if "m" in self.opts.tasks and "p" in self.opts.tasks:
            # ----------------------------------------
            # ----- Masker and Painter Loading -----
            # ----------------------------------------

            # want to resume a pm model but no path was provided:
            # resume a single pm model from output_path
            if all([str(p) == "none" for p in [m_path, p_path, pm_path]]):
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            # want to resume a pm model with a pm_path provided:
            # resume a single pm model from load_paths.pm
            # depending on whether a dir or a file is specified
            elif str(pm_path) != "none":
                assert pm_path.exists()

                if pm_path.is_dir():
                    checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
                else:
                    assert pm_path.suffix == ".pth"
                    checkpoint_path = pm_path

                print("Resuming P+M model from", str(checkpoint_path))
                checkpoint = torch.load(checkpoint_path, map_location=map_loc)

            # want to resume a pm model, pm_path not provided:
            # m_path and p_path must be provided as dirs or pth files
            elif m_path != p_path:
                assert m_path.exists()
                assert p_path.exists()

                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"

                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"

                assert m_path.suffix == ".pth"
                assert p_path.suffix == ".pth"

                print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}")
                m_checkpoint = torch.load(m_path, map_location=map_loc)
                p_checkpoint = torch.load(p_path, map_location=map_loc)
                # merge the two checkpoints into a single state dict
                checkpoint = merge(m_checkpoint, p_checkpoint)

            else:
                raise ValueError(
                    "Cannot resume a P+M model with provided load_paths:\n{}".format(
                        self.opts.load_paths
                    )
                )

        else:
            # ----------------------------------
            # ----- Single Model Loading -----
            # ----------------------------------

            # cannot specify both paths
            if str(m_path) != "none" and str(p_path) != "none":
                raise ValueError(
                    "Opts tasks are {} but received 2 values for the load_paths".format(
                        self.opts.tasks
                    )
                )

            # specified m
            elif str(m_path) != "none":
                assert m_path.exists()
                assert "m" in self.opts.tasks
                model = "M"
                if m_path.is_dir():
                    m_path = m_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = m_path

            # specified p
            elif str(p_path) != "none":
                assert p_path.exists()
                assert "p" in self.opts.tasks
                model = "P"
                if p_path.is_dir():
                    p_path = p_path / "checkpoints/latest_ckpt.pth"
                checkpoint_path = p_path

            # specified neither p nor m: resume from output_path
            else:
                model = "P" if "p" in self.opts.tasks else "M"
                checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"

            print(f"Resuming {model} model from {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path, map_location=map_loc)

        # On TPUs must send the data to the xla device as it cannot be mapped
        # there directly from torch.load
        if tpu:
            checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)

        # -----------------------
        # ----- Restore G -----
        # -----------------------
        if inference:
            # non-strict load: tolerate partial checkpoints at inference time
            incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
            if incompatible_keys.missing_keys:
                print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
                print(incompatible_keys.missing_keys)
            if incompatible_keys.unexpected_keys:
                print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
                print(incompatible_keys.unexpected_keys)
        else:
            self.G.load_state_dict(checkpoint["G"])

        if inference:
            # only G is needed to infer
            print("Done loading checkpoints.")
            return

        self.g_opt.load_state_dict(checkpoint["g_opt"])

        # ------------------------------
        # ----- Resume scheduler -----
        # ------------------------------
        # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
        # NOTE(review): this fast-forward reads self.logger.epoch *before* it
        # is restored from the checkpoint below — confirm the intended order.
        for _ in range(self.logger.epoch + 1):
            self.update_learning_rates()

        # -----------------------
        # ----- Restore D -----
        # -----------------------
        if self.D is not None and get_num_params(self.D) > 0:
            self.D.load_state_dict(checkpoint["D"])
            self.d_opt.load_state_dict(checkpoint["d_opt"])

        # ----------------------------
        # ----- Restore logger -----
        # ----------------------------
        self.logger.epoch = checkpoint["epoch"]
        self.logger.global_step = checkpoint["step"]
        # NOTE(review): assumes a comet experiment exists (self.exp not None)
        self.exp.log_text(
            "Resuming from epoch {} & step {}".format(
                checkpoint["epoch"], checkpoint["step"]
            )
        )
        # Round step to even number for extraGradient
        if self.logger.global_step % 2 != 0:
            self.logger.global_step += 1
580
+
581
+ def eval_mode(self):
582
+ """
583
+ Set trainer's models in eval mode
584
+ """
585
+ if self.G is not None:
586
+ self.G.eval()
587
+ if self.D is not None:
588
+ self.D.eval()
589
+ self.current_mode = "eval"
590
+
591
+ def train_mode(self):
592
+ """
593
+ Set trainer's models in train mode
594
+ """
595
+ if self.G is not None:
596
+ self.G.train()
597
+ if self.D is not None:
598
+ self.D.train()
599
+
600
+ self.current_mode = "train"
601
+
602
+ def assert_z_matches_x(self, x, z):
603
+ assert x.shape[0] == (
604
+ z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0]
605
+ ), "x-> {}, z->{}".format(
606
+ x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape
607
+ )
608
+
609
+ def batch_to_device(self, b):
610
+ """sends the data in b to self.device
611
+
612
+ Args:
613
+ b (dict): the batch dictionnay
614
+
615
+ Returns:
616
+ dict: the batch dictionnary with its "data" field sent to self.device
617
+ """
618
+ for task, tensor in b["data"].items():
619
+ b["data"][task] = tensor.to(self.device)
620
+ return b
621
+
622
+ def sample_painter_z(self, batch_size):
623
+ return self.G.sample_painter_z(batch_size, self.device)
624
+
625
+ @property
626
+ def train_loaders(self):
627
+ """Get a zip of all training loaders
628
+
629
+ Returns:
630
+ generator: zip generator yielding tuples:
631
+ (batch_rf, batch_rn, batch_sf, batch_sn)
632
+ """
633
+ return zip(*list(self.loaders["train"].values()))
634
+
635
+ @property
636
+ def val_loaders(self):
637
+ """Get a zip of all validation loaders
638
+
639
+ Returns:
640
+ generator: zip generator yielding tuples:
641
+ (batch_rf, batch_rn, batch_sf, batch_sn)
642
+ """
643
+ return zip(*list(self.loaders["val"].values()))
644
+
645
+ def compute_latent_shape(self):
646
+ """Compute the latent shape, i.e. the Encoder's output shape,
647
+ from a batch.
648
+
649
+ Raises:
650
+ ValueError: If no loader, the latent_shape cannot be inferred
651
+
652
+ Returns:
653
+ tuple: (c, h, w)
654
+ """
655
+ x = None
656
+ for mode in self.all_loaders:
657
+ for domain in self.all_loaders.loaders[mode]:
658
+ x = (
659
+ self.all_loaders[mode][domain]
660
+ .dataset[0]["data"]["x"]
661
+ .to(self.device)
662
+ )
663
+ break
664
+ if x is not None:
665
+ break
666
+
667
+ if x is None:
668
+ raise ValueError("No batch found to compute_latent_shape")
669
+
670
+ x = x.unsqueeze(0)
671
+ z = self.G.encode(x)
672
+ return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]
673
+
674
+ def g_opt_step(self):
675
+ """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
676
+ step every other step
677
+ """
678
+ if "extra" in self.opts.gen.opt.optimizer.lower() and (
679
+ self.logger.global_step % 2 == 0
680
+ ):
681
+ self.g_opt.extrapolation()
682
+ else:
683
+ self.g_opt.step()
684
+
685
+ def d_opt_step(self):
686
+ """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
687
+ step every other step
688
+ """
689
+ if "extra" in self.opts.dis.opt.optimizer.lower() and (
690
+ self.logger.global_step % 2 == 0
691
+ ):
692
+ self.d_opt.extrapolation()
693
+ else:
694
+ self.d_opt.step()
695
+
696
+ def update_learning_rates(self):
697
+ if self.g_scheduler is not None:
698
+ self.g_scheduler.step()
699
+ if self.d_scheduler is not None:
700
+ self.d_scheduler.step()
701
+
702
    def setup(self, inference=False):
        """Prepare the trainer before it can be used to train the models:
        * builds the data loaders (unless in inference mode)
        * initializes G; and D, optimizers, losses and display images when
          not in inference mode
        * optionally resumes from checkpoints (opts.train.resume)

        Args:
            inference (bool, optional): only build/load G and switch to eval
                mode. Defaults to False.
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time
        verbose = self.verbose

        if not inference:
            self.all_loaders = get_all_loaders(self.opts)

        # -----------------------
        # ----- Generator -----
        # -----------------------
        __t = time()
        print("Creating generator...")

        self.G: OmniGenerator = create_generator(
            self.opts, device=self.device, no_init=inference, verbose=verbose
        )

        # truthy when the painter has parameters or a validation painter
        # could be loaded
        self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()

        if self.has_painter:
            self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)

        print(f"Generator OK in {time() - __t:.1f}s.")

        if inference:  # Inference mode: no more than a Generator needed
            print("Inference mode: no Discriminator, no optimizers")
            print_num_parameters(self)
            self.switch_data(to="base")
            if self.opts.train.resume:
                self.resume(True)
            self.eval_mode()
            print("Trainer is in evaluation mode.")
            print("Setup done.")
            self.is_setup = True
            return

        # ---------------------------
        # ----- Discriminator -----
        # ---------------------------

        self.D: OmniDiscriminator = create_discriminator(
            self.opts, self.device, verbose=verbose
        )
        print("Discriminator OK.")

        print_num_parameters(self)

        # --------------------------
        # ----- Optimization -----
        # --------------------------
        # Get different optimizers for each task (different learning rates)
        self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
            self.G, self.opts.gen.opt, self.opts.tasks
        )

        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
                self.D, self.opts.dis.opt, self.opts.tasks, True
            )
        else:
            # no discriminator parameters: nothing to optimize
            self.d_opt, self.d_scheduler = None, None

        self.losses = get_losses(self.opts, verbose, device=self.device)

        # differentiable augmentations for the painter's GAN training
        if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
            self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)

        if verbose > 0:
            for mode, mode_dict in self.all_loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print(
                        "Loader {} {} : {}".format(
                            mode, domain, len(domain_loader.dataset)
                        )
                    )

        # ----------------------------
        # ----- Display images -----
        # ----------------------------
        self.set_display_images()

        # -------------------------------
        # ----- Log Architectures -----
        # -------------------------------
        self.logger.log_architecture()

        # -----------------------------
        # ----- Set data source -----
        # -----------------------------
        if self.kitti_pretrain:
            self.switch_data(to="kitti")
        else:
            self.switch_data(to="base")

        # -------------------------
        # ----- Setup Done. -----
        # -------------------------
        # clear the carriage-returned progress line
        print(" " * 50, end="\r")
        print("Done creating display images")

        if self.opts.train.resume:
            print("Resuming Model (inference: False)")
            self.resume(False)
        else:
            print("Not resuming: starting a new model")

        print("Setup done.")
        self.is_setup = True
816
+
817
    def switch_data(self, to="kitti"):
        """Switch the trainer's active data source ("kitti" or "base").

        Selects the matching display images and rebuilds `self.loaders` from
        `self.all_loaders`: kitti-only (exposed under the "s" domain) when
        `to == "kitti"`, every non-kitti domain otherwise.

        Args:
            to (str, optional): "kitti" or "base". Defaults to "kitti".
        """
        # log which method triggered the switch, for debugging
        caller = inspect.stack()[1].function
        print(f"[{caller}] Switching data source to", to)
        self.data_source = to
        if to == "kitti":
            self.display_images = self.kitty_display_images
            if self.all_loaders is not None:
                # kitti data plays the simulated ("s") domain
                self.loaders = {
                    mode: {"s": self.all_loaders[mode]["kitti"]}
                    for mode in self.all_loaders
                }
        else:
            self.display_images = self.base_display_images
            if self.all_loaders is not None:
                self.loaders = {
                    mode: {
                        domain: self.all_loaders[mode][domain]
                        for domain in self.all_loaders[mode]
                        if domain != "kitti"
                    }
                    for mode in self.all_loaders
                }
        # Extra* optimizers alternate extrapolation (even step) / step (odd):
        # ensure the first update after a switch is an extrapolation step
        if (
            self.logger.global_step % 2 != 0
            and "extra" in self.opts.dis.opt.optimizer.lower()
        ):
            print(
                "Warning: artificially bumping step to run an extrapolation step first."
            )
            self.logger.global_step += 1
847
+
848
    def set_display_images(self, use_all=False):
        """Select and cache the samples used for logged display images, per
        mode (train/val) and domain, into `self.base_display_images` and —
        when kitti pre-training is enabled — `self.kitty_display_images`.

        Args:
            use_all (bool, optional): use every sample of each dataset
                instead of the indices returned by `get_display_indices`.
                Defaults to False.
        """
        for mode, mode_dict in self.all_loaders.items():

            if self.kitti_pretrain:
                self.kitty_display_images[mode] = {}
            self.base_display_images[mode] = {}

            for domain in mode_dict:

                # kitti samples go to their own dict; they are skipped
                # entirely when not pre-training on kitti
                if self.kitti_pretrain and domain == "kitti":
                    target_dict = self.kitty_display_images
                else:
                    if domain == "kitti":
                        continue
                    target_dict = self.base_display_images

                dataset = self.all_loaders[mode][domain].dataset
                display_indices = (
                    get_display_indices(self.opts, domain, len(dataset))
                    if not use_all
                    else list(range(len(dataset)))
                )
                ldis = len(display_indices)
                print(
                    f" Creating {ldis} {mode} {domain} display images...",
                    end="\r",
                    flush=True,
                )
                # the inline print is a deliberate progress-logging hack:
                # print(...) returns None, so the filter reduces to
                # `i < len(dataset)` (out-of-range indices are dropped)
                target_dict[mode][domain] = [
                    Dict(dataset[i])
                    for i in display_indices
                    if (print(f"({i})", end="\r") is None and i < len(dataset))
                ]
                if self.exp is not None:
                    # record which files were selected as display images
                    for im_id, d in enumerate(target_dict[mode][domain]):
                        self.exp.log_parameter(
                            "display_image_{}_{}_{}".format(mode, domain, im_id),
                            d["paths"],
                        )
887
+
888
    def train(self):
        """Main training loop. For each epoch:
        * train (run_epoch)
        * eval (run_evaluation)
        * save

        Also enables pl4m (painter loss for masker) at `opts.gen.p.pl4m_epoch`
        and ends the kitti / pseudo-label pre-training phases after their
        configured number of epochs.
        """
        assert self.is_setup

        # `self.logger.epoch` is the loop variable so the current epoch is
        # always available to the rest of the trainer (logging, save, ...)
        for self.logger.epoch in range(
            self.logger.epoch, self.logger.epoch + self.opts.train.epochs
        ):
            # backprop painter's disc loss to masker
            if (
                self.logger.epoch == self.opts.gen.p.pl4m_epoch
                and get_num_params(self.G.painter) > 0
                and "p" in self.opts.tasks
                and self.opts.gen.m.use_pl4m
            ):
                print(
                    "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
                )
                self.use_pl4m = True

            self.run_epoch()
            self.run_evaluation(verbose=1)
            self.save()

            # end vkitti2 pre-training
            if self.logger.epoch == self.opts.train.kitti.epochs - 1:
                self.switch_data(to="base")
                self.kitti_pretrain = False

            # end pseudo training
            if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
                self.pseudo_training_tasks = set()
924
    def run_epoch(self):
        """Runs an epoch:
        * checks trainer is setup
        * gets a tuple of batches per domain
        * sends batches to device
        * updates sequentially G, D
        """
        assert self.is_setup
        self.train_mode()
        if self.exp is not None:
            self.exp.log_parameter("epoch", self.logger.epoch)
        # an epoch is as long as the shortest domain loader
        epoch_len = min(len(loader) for loader in self.loaders["train"].values())
        epoch_desc = "Epoch {}".format(self.logger.epoch)
        self.logger.time.epoch_start = time()

        for multi_batch_tuple in tqdm(
            self.train_loaders,
            desc=epoch_desc,
            total=epoch_len,
            mininterval=0.5,
            unit="batch",
        ):

            self.logger.time.step_start = time()
            multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple)

            # The `[0]` is because the domain is contained in a list
            multi_domain_batch = {
                batch["domain"][0]: self.batch_to_device(batch)
                for batch in multi_batch_tuple
            }
            # ------------------------------
            # ----- Update Generator -----
            # ------------------------------

            # freeze params of the discriminator
            if self.d_opt is not None:
                for param in self.D.parameters():
                    param.requires_grad = False

            self.update_G(multi_domain_batch)

            # ----------------------------------
            # ----- Update Discriminator -----
            # ----------------------------------

            # unfreeze params of the discriminator
            if self.d_opt is not None and not self.kitti_pretrain:
                for param in self.D.parameters():
                    param.requires_grad = True

            self.update_D(multi_domain_batch)

            # -------------------------
            # ----- Log Metrics -----
            # -------------------------
            self.logger.global_step += 1
            self.logger.log_step_time(time())

        # LR schedulers are stepped once per epoch (skipped during kitti
        # pre-training)
        if not self.kitti_pretrain:
            self.update_learning_rates()

        self.logger.log_learning_rates()
        self.logger.log_epoch_time(time())
988
+
989
    def update_G(self, multi_domain_batch, verbose=0):
        """Perform an update on g from multi_domain_batch which is a dictionary
        domain => batch

        * automatic mixed precision according to self.opts.train.amp
        * compute loss for each task
        * loss.backward()
        * g_opt_step()
        * g_opt.step() or .extrapolation() depending on self.logger.global_step
        * logs losses on comet.ml with self.logger.log_losses(model_to_update="G")

        Args:
            multi_domain_batch (dict): dictionary of domain batches
            verbose (int, optional): forwarded to get_G_loss. Defaults to 0.
        """
        zero_grad(self.G)
        if self.opts.train.amp:
            # AMP path steps through the GradScaler directly (ExtraAdam is
            # rejected together with AMP at construction time)
            with autocast():
                g_loss = self.get_G_loss(multi_domain_batch, verbose)
            self.grad_scaler_g.scale(g_loss).backward()
            self.grad_scaler_g.step(self.g_opt)
            self.grad_scaler_g.update()
        else:
            g_loss = self.get_G_loss(multi_domain_batch, verbose)
            g_loss.backward()
            self.g_opt_step()

        self.logger.log_losses(model_to_update="G", mode="train")
1016
+
1017
    def update_D(self, multi_domain_batch, verbose=0):
        """Compute and backprop the discriminators' loss then step the
        discriminator optimizer, with AMP scaling when enabled. Logs the
        total and per-component losses.

        Args:
            multi_domain_batch (dict): dictionary of domain batches
            verbose (int, optional): forwarded to get_D_loss. Defaults to 0.
        """
        zero_grad(self.D)

        if self.opts.train.amp:
            # AMP path steps through the GradScaler directly (ExtraAdam is
            # rejected together with AMP at construction time)
            with autocast():
                d_loss = self.get_D_loss(multi_domain_batch, verbose)
            self.grad_scaler_d.scale(d_loss).backward()
            self.grad_scaler_d.step(self.d_opt)
            self.grad_scaler_d.update()
        else:
            d_loss = self.get_D_loss(multi_domain_batch, verbose)
            d_loss.backward()
            self.d_opt_step()

        self.logger.losses.disc.total_loss = d_loss.item()
        self.logger.log_losses(model_to_update="D", mode="train")
1033
+
1034
    def get_D_loss(self, multi_domain_batch, verbose=0):
        """Compute the discriminators' losses:

        * for each domain-specific batch:
        * encode the image
        * get the conditioning tensor if using spade
        * source domain is the data's domain, sequentially r|s then f|n
        * get the target domain accordingly
        * compute the translated image from the data
        * compute the source domain discriminator's loss on the data
        * compute the target domain discriminator's loss on the translated image

        # ? In this setting, each D[decoder][domain] is updated twice towards
        # real or fake data

        See readme's update d section for details

        Args:
            multi_domain_batch (dict): dictionary mapping domain names to
                batches

        Returns:
            torch.Tensor: sum of all the discriminators' losses
        """

        # accumulators per decoder; the painter's structure depends on
        # whether a local discriminator is used
        disc_loss = {
            "m": {"Advent": 0},
            "s": {"Advent": 0},
        }
        if self.opts.dis.p.use_local_discriminator:
            disc_loss["p"] = {"global": 0, "local": 0}
        else:
            disc_loss["p"] = {"gan": 0}

        for domain, batch in multi_domain_batch.items():
            x = batch["data"]["x"]

            # ---------------------
            # ----- Painter -----
            # ---------------------
            if domain == "rf" and self.has_painter:
                m = batch["data"]["m"]
                # sample vector
                with torch.no_grad():
                    # see spade compute_discriminator_loss
                    fake = self.G.paint(m, x)
                    if self.opts.gen.p.diff_aug.use:
                        fake = self.diff_transforms(fake)
                        x = self.diff_transforms(x)
                    # detach then re-enable grad so D's loss does not flow
                    # back into the painter
                    fake = fake.detach()
                    fake.requires_grad_()

                if self.opts.dis.p.use_local_discriminator:
                    fake_d_global = self.D["p"]["global"](fake)
                    real_d_global = self.D["p"]["global"](x)

                    # the local discriminator only sees the masked area
                    fake_d_local = self.D["p"]["local"](fake * m)
                    real_d_local = self.D["p"]["local"](x * m)

                    global_loss = self.losses["D"]["p"](fake_d_global, False, True)
                    global_loss += self.losses["D"]["p"](real_d_global, True, True)

                    local_loss = self.losses["D"]["p"](fake_d_local, False, True)
                    local_loss += self.losses["D"]["p"](real_d_local, True, True)

                    disc_loss["p"]["global"] += global_loss
                    disc_loss["p"]["local"] += local_loss
                else:
                    # concatenate mask and image as D's input, run real and
                    # fake through D in a single forward pass
                    real_cat = torch.cat([m, x], axis=1)
                    fake_cat = torch.cat([m, fake], axis=1)
                    real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
                    real_fake_d = self.D["p"](real_fake_cat)
                    real_d, fake_d = divide_pred(real_fake_d)
                    disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
                    disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)

            # --------------------
            # ----- Masker -----
            # --------------------
            else:
                z = self.G.encode(x)
                s_pred = d_pred = cond = z_depth = None

                if "s" in batch["data"]:
                    # depth predictions condition segmentation when DADA
                    # is enabled
                    if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
                        d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, s_pred = self.masker_s_loss(
                        x, z, d_pred, z_depth, None, domain, for_="D"
                    )
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["s"]["Advent"] += step_loss

                if "m" in batch["data"]:
                    if "d" in self.opts.tasks:
                        # reuse d_pred when already computed above
                        if self.opts.gen.m.use_spade:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)
                            cond = self.G.make_m_cond(d_pred, s_pred, x)
                        elif self.opts.gen.m.use_dada:
                            if d_pred is None:
                                d_pred, z_depth = self.G.decoders["d"](z)

                    step_loss, _ = self.masker_m_loss(
                        x,
                        z,
                        None,
                        domain,
                        for_="D",
                        cond=cond,
                        z_depth=z_depth,
                        depth_preds=d_pred,
                    )
                    step_loss *= self.opts.train.lambdas.advent.adv_main
                    disc_loss["m"]["Advent"] += step_loss

        # log scalar values per decoder / component
        self.logger.losses.disc.update(
            {
                dom: {
                    k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in d.items()
                }
                for dom, d in disc_loss.items()
            }
        )

        loss = sum(v for d in disc_loss.values() for k, v in d.items())
        return loss
1161
+
1162
def get_G_loss(self, multi_domain_batch, verbose=0):
    """Compute the total generator loss: masker + painter.

    The masker term is computed whenever any of the "m"/"s"/"d" tasks is
    enabled; the painter term whenever "p" is enabled (skipped during
    kitti pretraining). Each term is logged under
    self.logger.losses.gen.{masker,painter,total_loss}.

    Args:
        multi_domain_batch (dict): domain name -> batch from the loaders
        verbose (int, optional): verbosity level (currently unused here)

    Returns:
        torch.Tensor: scalar total generator loss
    """
    total = 0
    masker_loss = painter_loss = None

    # Masker branch: active if any of the m/s/d tasks is trained
    if any(task in self.opts.tasks for task in "msd"):
        masker_loss = self.get_masker_loss(multi_domain_batch)
        self.logger.losses.gen.masker = masker_loss.item()
        total = total + masker_loss

    # Painter branch: skipped while pretraining the masker on kitti
    if "p" in self.opts.tasks and not self.kitti_pretrain:
        painter_loss = self.get_painter_loss(multi_domain_batch)
        self.logger.losses.gen.painter = painter_loss.item()
        total = total + painter_loss

    # At least one branch must have contributed a tensor loss
    assert total != 0 and not isinstance(total, int), "No update in get_G_loss!"

    self.logger.losses.gen.total_loss = total.item()

    return total
1183
+
1184
def get_masker_loss(self, multi_domain_batch):
    """Compute the generator-side loss for all masker-related tasks.

    For each domain batch (the painter's flooded domain "rf" is skipped),
    the input is encoded once and each task present in the batch is scored:

    * "d": depth loss (masker_d_loss)
    * "s": segmentation loss (masker_s_loss), reusing the depth
      predictions / depth features when already computed
    * "m": mask loss (masker_m_loss), optionally conditioned on a SPADE
      map built from the depth and segmentation predictions

    Per-task, per-domain values are logged under
    self.logger.losses.gen.task[task][domain].

    Args:
        multi_domain_batch (dict): dictionnary mapping domain names to
            batches from the trainer's loaders

    Returns:
        torch.Tensor: scalar loss tensor, weighted according to
            opts.train.lambdas
    """
    m_loss = 0
    for domain, batch in multi_domain_batch.items():
        # The flooded domain only concerns the painter
        if domain == "rf":
            continue

        x = batch["data"]["x"]
        # Shared encoder features, reused by every task decoder below
        z = self.G.encode(x)

        # --------------------------------------
        # ----- task-specific losses (2) -----
        # --------------------------------------
        d_pred = s_pred = z_depth = None
        # Order matters: "d" runs first so its predictions/features can
        # be reused by the "s" and "m" losses.
        for task in ["d", "s", "m"]:
            if task not in batch["data"]:
                continue

            target = batch["data"][task]

            if task == "d":
                loss, d_pred, z_depth = self.masker_d_loss(
                    x, z, target, domain, "G"
                )
                m_loss += loss
                self.logger.losses.gen.task["d"][domain] = loss.item()

            elif task == "s":
                loss, s_pred = self.masker_s_loss(
                    x, z, d_pred, z_depth, target, domain, "G"
                )
                m_loss += loss
                self.logger.losses.gen.task["s"][domain] = loss.item()

            elif task == "m":
                cond = None
                if self.opts.gen.m.use_spade:
                    # NOTE(review): clones (not detaches) when detach is
                    # False — confirm a .detach() was not intended here
                    if not self.opts.gen.m.detach:
                        d_pred = d_pred.clone()
                        s_pred = s_pred.clone()
                    cond = self.G.make_m_cond(d_pred, s_pred, x)

                loss, _ = self.masker_m_loss(
                    x,
                    z,
                    target,
                    domain,
                    "G",
                    cond=cond,
                    z_depth=z_depth,
                    depth_preds=d_pred,
                )
                m_loss += loss
                self.logger.losses.gen.task["m"][domain] = loss.item()

    return m_loss
1255
+
1256
def get_painter_loss(self, multi_domain_batch):
    """Compute the generator-side painter loss on the flooded domain "rf".

    The painter in-paints the water area of x (given by mask m) and is
    scored with VGG (perceptual), TV, context, reconstruction, GAN and
    feature-matching losses. Each term is enabled by a non-zero lambda in
    opts.train.lambdas.G.p (feature matching additionally requires
    opts.dis.p.get_intermediate_features). Individual values are logged
    under self.logger.losses.gen.p.

    Args:
        multi_domain_batch (dict): dictionnary mapping domain names to
            batches from the trainer's loaders

    Returns:
        torch.Tensor: scalar loss tensor, weighted according to
            opts.train.lambdas
    """
    step_loss = 0
    # self.g_opt.zero_grad()
    lambdas = self.opts.train.lambdas
    batch_domain = "rf"
    batch = multi_domain_batch[batch_domain]

    x = batch["data"]["x"]
    # ! different mask: hides water to be reconstructed
    # ! 1 for water, 0 otherwise
    m = batch["data"]["m"]
    fake_flooded = self.G.paint(m, x)

    # ----------------------
    # ----- VGG Loss -----
    # ----------------------
    # Perceptual loss restricted to the painted (masked) area
    if lambdas.G.p.vgg != 0:
        loss = self.losses["G"]["p"]["vgg"](
            vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
        )
        loss *= lambdas.G.p.vgg
        self.logger.losses.gen.p.vgg = loss.item()
        step_loss += loss

    # ---------------------
    # ----- TV Loss -----
    # ---------------------
    # Total variation of the painted area (smoothness prior)
    if lambdas.G.p.tv != 0:
        loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
        loss *= lambdas.G.p.tv
        self.logger.losses.gen.p.tv = loss.item()
        step_loss += loss

    # --------------------------
    # ----- Context Loss -----
    # --------------------------
    if lambdas.G.p.context != 0:
        loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
        loss *= lambdas.G.p.context
        self.logger.losses.gen.p.context = loss.item()
        step_loss += loss

    # ---------------------------------
    # ----- Reconstruction Loss -----
    # ---------------------------------
    if lambdas.G.p.reconstruction != 0:
        loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
        loss *= lambdas.G.p.reconstruction
        self.logger.losses.gen.p.reconstruction = loss.item()
        step_loss += loss

    # -------------------------------------
    # ----- Local & Global GAN Loss -----
    # -------------------------------------
    # Differentiable augmentations applied to both fake and real images
    if self.opts.gen.p.diff_aug.use:
        fake_flooded = self.diff_transforms(fake_flooded)
        x = self.diff_transforms(x)

    if self.opts.dis.p.use_local_discriminator:
        fake_d_global = self.D["p"]["global"](fake_flooded)
        fake_d_local = self.D["p"]["local"](fake_flooded * m)

        real_d_global = self.D["p"]["global"](x)

        # Note: discriminator returns [out_1,...,out_num_D] outputs
        # Each out_i is a list [feat1, feat2, ..., pred_i]

        # NOTE(review): this 0 is immediately overwritten below —
        # harmless but redundant
        self.logger.losses.gen.p.gan = 0

        loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
        loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
        loss *= lambdas.G["p"]["gan"]

        self.logger.losses.gen.p.gan = loss.item()

        step_loss += loss

        # -----------------------------------
        # ----- Feature Matching Loss -----
        # -----------------------------------
        # (only on global discriminator)
        # Order must be real, fake
        # NOTE(review): unlike the single-discriminator branch below,
        # this branch does not check lambdas.G.p.featmatch != 0 —
        # confirm intended
        if self.opts.dis.p.get_intermediate_features:
            loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
            loss *= lambdas.G["p"]["featmatch"]

            if isinstance(loss, float):
                self.logger.losses.gen.p.featmatch = loss
            else:
                self.logger.losses.gen.p.featmatch = loss.item()

            step_loss += loss

    # -------------------------------------------
    # ----- Single Discriminator GAN Loss -----
    # -------------------------------------------
    else:
        # Condition the discriminator on the mask by concatenating it to
        # the image; real and fake are stacked in one batch then split
        real_cat = torch.cat([m, x], axis=1)
        fake_cat = torch.cat([m, fake_flooded], axis=1)
        real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)

        real_fake_d = self.D["p"](real_fake_cat)
        real_d, fake_d = divide_pred(real_fake_d)

        loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
        self.logger.losses.gen.p.gan = loss.item()
        step_loss += loss

        # -----------------------------------
        # ----- Feature Matching Loss -----
        # -----------------------------------
        if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
            loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
            loss *= lambdas.G.p.featmatch

            if isinstance(loss, float):
                self.logger.losses.gen.p.featmatch = loss
            else:
                self.logger.losses.gen.p.featmatch = loss.item()

            step_loss += loss

    return step_loss
1388
+
1389
def masker_d_loss(self, x, z, target, domain, for_="G"):
    """Compute the depth loss and the depth decoder's outputs.

    Fixes over the previous version:
    * the weighted loss is no longer computed when it would be discarded
      (weight == 0, or real domain without pseudo-label depth training);
    * ``target`` is no longer squeezed **in place** (``squeeze_``), which
      mutated the caller's tensor.

    Args:
        x (torch.Tensor): input image batch
        z (torch.Tensor): encoder features for x
        target (torch.Tensor): depth target (same batch size as x)
        domain (str): source domain of the batch ("r" real or "s" sim)
        for_ (str): "G" or "D" (kept for interface symmetry with the
            other masker_*_loss methods; not used in the computation)

    Returns:
        tuple: (scalar loss, depth predictions, depth decoder features)
    """
    assert for_ in {"G", "D"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0]
    weight = self.opts.train.lambdas.G.d.main

    # Predictions are always needed by the caller, even when the loss is 0
    prediction, z_depth = self.G.decoders["d"](z)

    # No supervision: zero weight, or a real-domain batch while depth is
    # not among the pseudo-label training tasks
    if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
        return torch.tensor(0.0, device=self.device), prediction, z_depth

    if self.opts.gen.d.classify.enable:
        # classification-style depth loss expects (N, H, W) targets
        target = target.squeeze(1)

    full_loss = self.losses["G"]["tasks"]["d"](prediction, target) * weight

    return full_loss, prediction, z_depth
1408
+
1409
def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
    """Compute the segmentation losses, for the generator or the ADVENT
    discriminator.

    Generator ("G") terms: supervised cross-entropy (sim labels, or
    pseudo-labels on real), entropy minimization on real, and the
    adversarial ADVENT term. Discriminator ("D") term: the ADVENT
    domain-classification loss (plus optional WGAN clipping / gradient
    penalty).

    Fix: the WGAN dispatch used ``gan_type == "GAN" or "WGAN_norm"``,
    which is always truthy ("WGAN_norm" is a non-empty string), so the
    WGAN clamp and gradient-penalty branches were unreachable.

    Args:
        x (torch.Tensor): input image batch
        z (torch.Tensor): encoder features for x
        depth_preds (torch.Tensor | None): depth predictions (DADA weighting)
        z_depth (torch.Tensor | None): depth decoder features (DADA)
        target (torch.Tensor | None): segmentation labels (None on "D" steps)
        domain (str): "r" (real) or "s" (simulated)
        for_ (str): "G" for generator losses, "D" for the discriminator's

    Returns:
        tuple(torch.Tensor, torch.Tensor | None): (scalar loss, seg preds)
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)
    softmax_preds = None

    # --------------------------
    # ----- Segmentation -----
    # --------------------------
    pred = None
    if for_ == "G" or self.opts.gen.s.use_advent:
        pred = self.G.decoders["s"](z, z_depth)

    # Supervised segmentation loss: crossent for sim domain,
    # crossent_pseudo for real ; loss is crossent in any case
    if for_ == "G":
        if domain == "s" or "s" in self.pseudo_training_tasks:
            if domain == "s":
                logger = self.logger.losses.gen.task["s"]["crossent"]
                weight = self.opts.train.lambdas.G["s"]["crossent"]
            else:
                logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
                weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]

            if weight != 0:
                # Cross-Entropy loss
                loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
                loss = loss_func(pred, target.squeeze(1)) * weight
                full_loss += loss
                logger[domain] = loss.item()

        if domain == "r":
            weight = self.opts.train.lambdas.G["s"]["minent"]
            if self.opts.gen.s.use_minent and weight != 0:
                softmax_preds = softmax(pred, dim=1)
                # Entropy minimization loss on unlabeled real images
                loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
                loss *= weight
                full_loss += loss

                self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()

    # Fool ADVENT discriminator
    if self.opts.gen.s.use_advent:
        # DADA: weigh the adversarial signal by detached depth predictions
        if self.opts.gen.s.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
        else:
            depth_preds = None

        if for_ == "D":
            domain_label = domain
            logger = {}  # throw-away: D losses are logged by the caller
            loss_func = self.losses["D"]["advent"]
            pred = pred.detach()
            weight = self.opts.train.lambdas.advent.adv_main
        else:
            # G step: push real predictions to look like sim ("s") ones
            domain_label = "s"
            logger = self.logger.losses.gen.task["s"]["advent"]
            loss_func = self.losses["G"]["tasks"]["s"]["advent"]
            weight = self.opts.train.lambdas.G["s"]["advent"]

        if (for_ == "D" or domain == "r") and weight != 0:
            if softmax_preds is None:
                softmax_preds = softmax(pred, dim=1)
            loss = loss_func(
                softmax_preds,
                self.domain_labels[domain_label],
                self.D["s"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

        if for_ == "D":
            # WGAN variants: weight clipping or gradient penalty.
            # (fix: previously `== "GAN" or "WGAN_norm"`, always truthy)
            if self.opts.dis.s.gan_type in {"GAN", "WGAN_norm"}:
                pass
            elif self.opts.dis.s.gan_type == "WGAN":
                for p in self.D["s"]["Advent"].parameters():
                    p.data.clamp_(
                        self.opts.dis.s.wgan_clamp_lower,
                        self.opts.dis.s.wgan_clamp_upper,
                    )
            elif self.opts.dis.s.gan_type == "WGAN_gp":
                prob_need_grad = autograd.Variable(pred, requires_grad=True)
                d_out = self.D["s"]["Advent"](prob_need_grad)
                gp = get_WGAN_gradient(prob_need_grad, d_out)
                full_loss += gp * self.opts.train.lambdas.advent.WGAN_gp
            else:
                raise NotImplementedError

    return full_loss, pred
1505
+
1506
def masker_m_loss(
    self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
):
    """Compute the mask losses, for the generator or the ADVENT
    discriminator.

    Generator ("G") terms: TV, supervised BCE on sim, and on real:
    ground-intersection, painter-loss-for-masker (pl4m) and entropy
    minimization — plus the adversarial ADVENT term. Discriminator ("D")
    term: the ADVENT domain-classification loss (plus optional WGAN
    clipping / gradient penalty).

    Fixes:
    * WGAN dispatch used ``gan_type == "GAN" or "WGAN_norm"`` (always
      truthy), making the WGAN branches unreachable;
    * the WGAN branches operated on ``self.D["s"]["Advent"]`` — a
      copy-paste from masker_s_loss — instead of the mask discriminator
      ``self.D["m"]["Advent"]`` used by the ADVENT loss above.

    Args:
        x (torch.Tensor): input image batch
        z (torch.Tensor): encoder features for x
        target (torch.Tensor | None): mask labels (None on "D" steps)
        domain (str): "r" (real) or "s" (simulated)
        for_ (str): "G" or "D"
        cond (torch.Tensor, optional): SPADE conditioning map
        z_depth (torch.Tensor, optional): depth decoder features
        depth_preds (torch.Tensor, optional): depth preds (DADA weighting)

    Returns:
        tuple(torch.Tensor, torch.Tensor): (scalar loss, mask probs with
            their complement, concatenated on the channel dim)
    """
    assert for_ in {"G", "D"}
    assert domain in {"r", "s"}
    self.assert_z_matches_x(x, z)
    assert x.shape[0] == target.shape[0] if target is not None else True
    full_loss = torch.tensor(0.0, device=self.device)

    pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
    pred_prob = sigmoid(pred_logits)
    pred_prob_complementary = 1 - pred_prob
    # 2-channel (p, 1-p) map, as expected by the entropy/ADVENT losses
    prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)

    if for_ == "G":
        # TV loss: smoothness prior on the predicted mask
        weight = self.opts.train.lambdas.G.m.tv
        if weight != 0:
            loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob) * weight
            full_loss += loss

            self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()

        weight = self.opts.train.lambdas.G.m.bce
        if domain == "s" and weight != 0:
            # Supervised BCE on sim labels (on logits)
            loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
            loss *= weight
            full_loss += loss
            self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()

        if domain == "r":

            weight = self.opts.train.lambdas.G["m"]["gi"]
            if self.opts.gen.m.use_ground_intersection and weight != 0:
                # GroundIntersection loss
                loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()

            weight = self.opts.train.lambdas.G.m.pl4m
            if self.use_pl4m and weight != 0:
                # Painter loss: the painted mask should fool the painter's D
                pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
                pl4m_loss *= weight
                full_loss += pl4m_loss
                self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()

            weight = self.opts.train.lambdas.advent.ent_main
            if self.opts.gen.m.use_minent and weight != 0:
                # MinEnt loss on unlabeled real images
                loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
                loss *= weight
                full_loss += loss
                self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()

    if self.opts.gen.m.use_advent:
        # AdvEnt loss
        if self.opts.gen.m.use_dada and depth_preds is not None:
            depth_preds = depth_preds.detach()
            depth_preds = torch.nn.functional.interpolate(
                depth_preds, size=x.shape[-2:], mode="nearest"
            )
        else:
            depth_preds = None

        if for_ == "D":
            domain_label = domain
            logger = {}  # throw-away: D losses are logged by the caller
            loss_func = self.losses["D"]["advent"]
            prob = prob.detach()
            weight = self.opts.train.lambdas.advent.adv_main
        else:
            # G step: push real predictions to look like sim ("s") ones
            domain_label = "s"
            logger = self.logger.losses.gen.task["m"]["advent"]
            loss_func = self.losses["G"]["tasks"]["m"]["advent"]
            weight = self.opts.train.lambdas.advent.adv_main

        if (for_ == "D" or domain == "r") and weight != 0:
            loss = loss_func(
                prob.to(self.device),
                self.domain_labels[domain_label],
                self.D["m"]["Advent"],
                depth_preds,
            )
            loss *= weight
            full_loss += loss
            logger[domain] = loss.item()

        if for_ == "D":
            # WGAN variants: weight clipping or gradient penalty.
            # (fix: previously `== "GAN" or "WGAN_norm"`, always truthy)
            if self.opts.dis.m.gan_type in {"GAN", "WGAN_norm"}:
                pass
            elif self.opts.dis.m.gan_type == "WGAN":
                # fix: clamp the mask discriminator (was D["s"]["Advent"])
                for p in self.D["m"]["Advent"].parameters():
                    p.data.clamp_(
                        self.opts.dis.m.wgan_clamp_lower,
                        self.opts.dis.m.wgan_clamp_upper,
                    )
            elif self.opts.dis.m.gan_type == "WGAN_gp":
                prob_need_grad = autograd.Variable(prob, requires_grad=True)
                # fix: penalize the mask discriminator (was D["s"]["Advent"])
                d_out = self.D["m"]["Advent"](prob_need_grad)
                gp = get_WGAN_gradient(prob_need_grad, d_out)
                full_loss += self.opts.train.lambdas.advent.WGAN_gp * gp
            else:
                raise NotImplementedError

    return full_loss, prob
1617
+
1618
def painter_loss_for_masker(self, x, m):
    """Adversarial painter loss used to train the masker (pl4m).

    Paints x with the (soft) mask m, scores the result with the painter's
    discriminator(s), and returns the generator-side GAN loss. The
    painter's parameters are frozen for the computation and unfrozen
    afterwards only when the painter is itself being trained.

    Args:
        x (torch.Tensor): input image batch
        m (torch.Tensor): predicted mask probabilities

    Returns:
        torch.Tensor: scalar GAN loss
    """
    # Freeze the painter: pl4m must only update the masker
    for painter_param in self.G.painter.parameters():
        painter_param.requires_grad = False
    # TODO for param in self.D.painter.parameters():
    #     param.requires_grad = False

    painted = self.G.paint(m, x)

    if self.opts.dis.p.use_local_discriminator:
        # Two discriminators: whole image + masked region
        d_global = self.D["p"]["global"](painted)
        d_local = self.D["p"]["local"](painted * m)

        # Note: discriminator returns [out_1,...,out_num_D] outputs
        # Each out_i is a list [feat1, feat2, ..., pred_i]

        pl4m_loss = self.losses["G"]["p"]["gan"](d_global, True, False)
        pl4m_loss = pl4m_loss + self.losses["G"]["p"]["gan"](d_local, True, False)
    else:
        # Single conditional discriminator on (mask, image) pairs;
        # real and fake are stacked in one batch, then split
        stacked = torch.cat(
            [torch.cat([m, x], axis=1), torch.cat([m, painted], axis=1)],
            dim=0,
        )
        _, fake_d = divide_pred(self.D["p"](stacked))
        pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)

    # Unfreeze the painter if it is being trained in this run
    if "p" in self.opts.tasks:
        for painter_param in self.G.painter.parameters():
            painter_param.requires_grad = True

    return pl4m_loss
1652
+
1653
@torch.no_grad()
def run_evaluation(self, verbose=0):
    """Run a full validation pass: average the generator losses over the
    validation loaders, log losses and images to comet (when available),
    compute image metrics and, when the painter is trained, the
    validation FID.

    Puts the model in eval mode for the duration and restores train mode
    at the end. Gradients are disabled by the decorator.

    Args:
        verbose (int, optional): forwarded to get_G_loss. Defaults to 0.
    """
    print("******************* Running Evaluation ***********************")
    start_time = time()
    self.eval_mode()
    val_logger = None
    nb_of_batches = None
    for i, multi_batch_tuple in enumerate(self.val_loaders):
        # create a dictionnary (domain => batch) from tuple
        # (batch_domain_0, ..., batch_domain_i)
        # and send it to self.device
        nb_of_batches = i + 1
        multi_domain_batch = {
            batch["domain"][0]: self.batch_to_device(batch)
            for batch in multi_batch_tuple
        }
        # populates self.logger.losses.generator as a side effect
        self.get_G_loss(multi_domain_batch, verbose)

        # accumulate per-batch loss dicts, averaged after the loop
        if val_logger is None:
            val_logger = deepcopy(self.logger.losses.generator)
        else:
            val_logger = sum_dict(val_logger, self.logger.losses.generator)

    # NOTE(review): if self.val_loaders yields no batch, val_logger and
    # nb_of_batches are still None here and div_dict will fail — confirm
    # loaders are never empty
    val_logger = div_dict(val_logger, nb_of_batches)
    self.logger.losses.generator = val_logger
    self.logger.log_losses(model_to_update="G", mode="val")

    for d in self.opts.domains:
        self.logger.log_comet_images("train", d)
        self.logger.log_comet_images("val", d)

    if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
        self.logger.log_comet_combined_images("train", "r")
        self.logger.log_comet_combined_images("val", "r")

    if self.exp is not None:
        print()

    if "m" in self.opts.tasks or "s" in self.opts.tasks:
        self.eval_images("val", "r")
        self.eval_images("val", "s")

    if "p" in self.opts.tasks and not self.kitti_pretrain:
        val_fid = compute_val_fid(self)
        if self.exp is not None:
            self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
        else:
            print("Validation FID Score", val_fid)

    self.train_mode()
    timing = int(time() - start_time)
    print("****************** Done in {}s *********************".format(timing))
1705
+
1706
def eval_images(self, mode, domain):
    """Compute accuracy and mIOU metrics on the display images of a
    (mode, domain) split, for the mask — and, when enabled, the
    segmentation and depth-classification tasks — then log them to comet
    or print them.

    Fix: the mask branch tested ``"m" in self.opts`` (a top-level opts
    key, never present) instead of ``"m" in self.opts.tasks`` — as done
    everywhere else in this file — so mask metrics were never computed.

    Args:
        mode (str): "train" or "val" split of self.display_images
        domain (str): domain key ("r", "s", ...); "rf" and unknown
            domains are skipped; "s" is remapped to "kitti" during
            kitti pretraining

    Returns:
        int: 0 (sentinel return value, kept for compatibility)
    """
    if domain == "s" and self.kitti_pretrain:
        domain = "kitti"
    if domain == "rf" or domain not in self.display_images[mode]:
        return

    metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
    # one list of per-image scores per (task, metric); averaged at the end
    metric_avg_scores = {"m": {}}
    if "s" in self.opts.tasks:
        metric_avg_scores["s"] = {}
    if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
        metric_avg_scores["d"] = {}

    for key in metric_funcs:
        for task in metric_avg_scores:
            metric_avg_scores[task][key] = []

    for im_set in self.display_images[mode][domain]:
        x = im_set["data"]["x"].unsqueeze(0).to(self.device)
        z = self.G.encode(x)

        s_pred = d_pred = z_depth = None

        if "d" in metric_avg_scores:
            d_pred, z_depth = self.G.decoders["d"](z)
            d_pred = d_pred.detach().cpu()

            # depth targets are only available in the sim domain
            if domain == "s":
                d = im_set["data"]["d"].unsqueeze(0).detach()

                for metric in metric_funcs:
                    metric_score = metric_funcs[metric](d_pred, d)
                    metric_avg_scores["d"][metric].append(metric_score)

        if "s" in metric_avg_scores:
            if z_depth is None:
                if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
                    _, z_depth = self.G.decoders["d"](z)
            s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
            s = im_set["data"]["s"].unsqueeze(0).detach()

            for metric in metric_funcs:
                metric_score = metric_funcs[metric](s_pred, s)
                metric_avg_scores["s"][metric].append(metric_score)

        # fix: was `if "m" in self.opts:` — checked opts' top-level keys,
        # not the enabled tasks, so this branch never ran
        if "m" in self.opts.tasks:
            cond = None
            if s_pred is not None and d_pred is not None:
                cond = self.G.make_m_cond(d_pred, s_pred, x)
            if z_depth is None:
                if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
                    _, z_depth = self.G.decoders["d"](z)

            pred_mask = (
                (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
            )
            pred_mask = (pred_mask > 0.5).to(torch.float32)
            # 2-channel (background, mask) probabilities for mIOU
            pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)

            m = im_set["data"]["m"].unsqueeze(0).detach()

            for metric in metric_funcs:
                if metric != "mIOU":
                    metric_score = metric_funcs[metric](pred_mask, m)
                else:
                    metric_score = metric_funcs[metric](pred_prob, m)

                metric_avg_scores["m"][metric].append(metric_score)

    # average each score list; empty lists become NaN then -1
    metric_avg_scores = {
        task: {
            metric: np.mean(values) if values else float("nan")
            for metric, values in met_dict.items()
        }
        for task, met_dict in metric_avg_scores.items()
    }
    metric_avg_scores = {
        task: {
            metric: value if not np.isnan(value) else -1
            for metric, value in met_dict.items()
        }
        for task, met_dict in metric_avg_scores.items()
    }
    if self.exp is not None:
        self.exp.log_metrics(
            flatten_opts(metric_avg_scores),
            prefix=f"metrics_{mode}_{domain}",
            step=self.logger.global_step,
        )
    else:
        print(f"metrics_{mode}_{domain}")
        print(flatten_opts(metric_avg_scores))

    return 0
1800
+
1801
def functional_test_mode(self):
    """Switch the trainer to functional-test mode.

    Redirects output_path to a throw-away directory under
    ~/climategan/functional_tests, marks it with an ``is_functional.test``
    sentinel file, logs the flag to comet when available, and registers
    the directory's deletion at interpreter exit.
    """
    import atexit

    test_dir = Path("~").expanduser() / "climategan" / "functional_tests"
    self.opts.output_path = test_dir
    Path(test_dir).mkdir(parents=True, exist_ok=True)

    # sentinel file: del_output_path only removes marked directories
    sentinel = Path(test_dir) / "is_functional.test"
    with open(sentinel, "w") as f:
        f.write("trainer functional test - delete this dir")

    if self.exp is not None:
        self.exp.log_parameter("is_functional_test", True)
    atexit.register(self.del_output_path)
1814
+
1815
def del_output_path(self, force=False):
    """Delete the trainer's output_path directory.

    Only removes the directory when it carries the ``is_functional.test``
    sentinel file (i.e. was created by functional_test_mode) or when
    force=True. No-op if the directory does not exist.

    Args:
        force (bool, optional): delete even without the sentinel file.
    """
    import shutil

    out_dir = Path(self.opts.output_path)
    if not out_dir.exists():
        return

    if force or (out_dir / "is_functional.test").exists():
        shutil.rmtree(self.opts.output_path)
1823
+
1824
def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
    """Apply the wildfire event transform to a batch of images.

    Segmentation predictions are computed on the fly when not provided,
    encoding x first if no latent z is given.

    Args:
        x (torch.Tensor): input tensor
        seg_preds (torch.Tensor, optional): precomputed semantic
            segmentation predictions for x
        z (torch.Tensor, optional): latent vector of encoded x; may be
            None if seg_preds is given
        z_depth (torch.Tensor, optional): depth features passed to the
            segmentation decoder

    Returns:
        torch.Tensor: wildfire version of x
    """
    if seg_preds is None:
        latent = z if z is not None else self.G.encode(x)
        seg_preds = self.G.decoders["s"](latent, z_depth)

    return add_fire(x, seg_preds, self.opts.events.fire)
1843
+
1844
def compute_flood(
    self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
):
    """Apply a flood (mask + paint) to an input image, with optionally
    pre-computed masker latent z or mask m.

    Args:
        x (torch.Tensor): B x C x H x W -1:1 input image
        z (torch.Tensor, optional): B x C x H x W masker latent vector.
            Defaults to None.
        z_depth (torch.Tensor, optional): depth features for the masker
            (DADA). Defaults to None.
        m (torch.Tensor, optional): B x 1 x H x W mask. Defaults to None.
        s (torch.Tensor, optional): segmentation, required when cloudy.
        cloudy (bool, optional): use the cloudy painter variant.
        bin_value (float, optional): mask binarization threshold.
            Set to -1 to use smooth masks (no binarization).

    Returns:
        torch.Tensor: B x 3 x H x W -1:1 flooded image
    """
    if m is None:
        if z is None:
            z = self.G.encode(x)
        needs_depth = (
            "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None
        )
        if needs_depth:
            _, z_depth = self.G.decoders["d"](z)
        m = self.G.mask(x=x, z=z, z_depth=z_depth)

    if bin_value >= 0:
        # binarize the soft mask at the requested threshold
        m = (m > bin_value).to(m.dtype)

    if cloudy:
        assert s is not None
        return self.G.paint_cloudy(m, x, s)

    return self.G.paint(m, x)
1878
+
1879
def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
    """Apply the smog event transform to a batch of images.

    Depth is predicted on the fly when not provided (encoding x first if
    needed), then used to attenuate the scene with an exponential
    transmission model and blend in airlight, before a yellow color
    filter is composited on top.

    Args:
        x (torch.Tensor): input image batch
        z (torch.Tensor, optional): precomputed encoder features for x
        d (torch.Tensor, optional): precomputed depth predictions
        s (torch.Tensor, optional): precomputed segmentation predictions
            (only used for the — unfinished — sky masking, see TODOs)
        use_sky_seg (bool): predict segmentation to build a sky mask
            (currently has no effect: sky_mask is never set — see TODOs)

    Returns:
        torch.Tensor: smogged version of x
    """
    # implementation from the paper:
    # HazeRD: An outdoor scene dataset and benchmark for single image dehazing
    sky_mask = None
    if d is None or (use_sky_seg and s is None):
        if z is None:
            z = self.G.encode(x)
        if d is None:
            d, _ = self.G.decoders["d"](z)
        if use_sky_seg and s is None:
            if "s" not in self.opts.tasks:
                raise ValueError(
                    "Cannot have "
                    + "(use_sky_seg is True and s is None and 's' not in tasks)"
                )
            s = self.G.decoders["s"](z)
            # TODO: s to sky mask
            # TODO: interpolate to d's size

    params = self.opts.events.smog

    # constant airlight color, broadcast over the batch
    airlight = params.airlight * torch.ones(3)
    airlight = airlight.view(1, -1, 1, 1).to(self.device)

    # work in linear RGB for the physical compositing
    irradiance = srgb2lrgb(x)

    # extinction coefficient per channel (beta scaled by visual range)
    beta = torch.tensor([params.beta / params.vr] * 3)
    beta = beta.view(1, -1, 1, 1).to(self.device)

    # invert the depth prediction and renormalize
    # NOTE(review): exact semantics depend on the project's `normalize`
    # helper (from climategan.tutils) — presumably min/max rescaling
    d = normalize(d, mini=0.3, maxi=1.0)
    d = 1.0 / d
    d = normalize(d, mini=0.1, maxi=1)

    # dead branch for now: sky_mask is never set above (see TODOs)
    if sky_mask is not None:
        d[sky_mask] = 1

    d = torch.nn.functional.interpolate(
        d, size=x.shape[-2:], mode="bilinear", align_corners=True
    )

    # replicate single-channel depth to the 3 color channels
    d = d.repeat(1, 3, 1, 1)

    # Beer-Lambert style transmission, then blend scene with airlight
    transmission = torch.exp(d * -beta)

    smogged = transmission * irradiance + (1 - transmission) * airlight

    smogged = lrgb2srgb(smogged)

    # add yellow filter
    alpha = params.alpha / 255
    yellow_mask = torch.Tensor([params.yellow_color]) / 255
    yellow_filter = (
        yellow_mask.unsqueeze(2)
        .unsqueeze(2)
        .repeat(1, 1, smogged.shape[-2], smogged.shape[-1])
        .to(self.device)
    )

    smogged = smogged * (1 - alpha) + yellow_filter * alpha

    return smogged
climategan/transforms.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data transforms for the loaders
2
+ """
3
+ import random
4
+ import traceback
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from skimage.color import rgba2rgb
11
+ from skimage.io import imread
12
+ from torchvision import transforms as trsfs
13
+ from torchvision.transforms.functional import (
14
+ adjust_brightness,
15
+ adjust_contrast,
16
+ adjust_saturation,
17
+ )
18
+
19
+ from climategan.tutils import normalize
20
+
21
+
22
def interpolation(task):
    """Return the F.interpolate keyword arguments suited to a task:
    nearest-neighbor for label-like maps (depth "d", mask "m",
    segmentation "s"), bilinear with align_corners for everything else
    (e.g. images "x")."""
    if task in ("d", "m", "s"):
        return {"mode": "nearest"}
    return {"mode": "bilinear", "align_corners": True}
27
+
28
+
29
class Resize:
    def __init__(self, target_size, keep_aspect_ratio=False):
        """
        Resize transform. Target_size can be an int or a tuple of ints,
        depending on whether both height and width should have the same
        final size or not.

        If keep_aspect_ratio is specified then target_size must be an int:
        the smallest dimension of x will be set to target_size and the largest
        dimension will be computed to the closest int keeping the original
        aspect ratio. e.g.
        >>> x = torch.rand(1, 3, 1200, 1800)
        >>> m = torch.rand(1, 1, 600, 600)
        >>> d = {"x": x, "m": m}
        >>> {k: v.shape for k, v in Resize(640, True)(d).items()}
        {"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)}

        Args:
            target_size (int | tuple(int) | dict): New size for the tensor;
                a dict maps task names to sizes, with a mandatory "default"
                key for tasks not listed.
            keep_aspect_ratio (bool, optional): Whether or not to keep aspect
                ratio when resizing. Requires target_size to be an int. If
                keeping aspect ratio, smallest dim will be set to
                target_size. Defaults to False.
        """
        if isinstance(target_size, (int, tuple, list)):
            if not isinstance(target_size, int) and not keep_aspect_ratio:
                assert len(target_size) == 2
                self.h, self.w = target_size
            else:
                if keep_aspect_ratio:
                    assert isinstance(target_size, int)
                self.h = self.w = target_size

            self.default_h = int(self.h)
            self.default_w = int(self.w)
            self.sizes = {}
        elif isinstance(target_size, dict):
            assert (
                not keep_aspect_ratio
            ), "dict target_size not compatible with keep_aspect_ratio"

            # per-task sizes (square), "default" is kept separately
            self.sizes = {
                k: {"h": v, "w": v} for k, v in target_size.items() if k != "default"
            }
            self.default_h = int(target_size["default"])
            self.default_w = int(target_size["default"])

        self.keep_aspect_ratio = keep_aspect_ratio

    def compute_new_default_size(self, tensor):
        """
        Compute the new size for a tensor depending on target size
        and keep_aspect_ratio.

        Args:
            tensor (torch.Tensor): 4D tensor N x C x H x W.

        Returns:
            tuple(int): (new_height, new_width)
        """
        if self.keep_aspect_ratio:
            h, w = tensor.shape[-2:]
            # smallest dim becomes target_size, the other one is scaled
            # to preserve the original aspect ratio
            if h < w:
                return (self.h, int(self.default_h * w / h))
            else:
                return (int(self.default_h * h / w), self.default_w)
        return (self.default_h, self.default_w)

    def compute_new_size_for_task(self, task):
        """Return the (h, w) configured for a task, falling back to the
        default size for tasks not listed in the target_size dict."""
        assert (
            not self.keep_aspect_ratio
        ), "compute_new_size_for_task is not compatible with keep aspect ratio"

        if task not in self.sizes:
            return (self.default_h, self.default_w)

        return (self.sizes[task]["h"], self.sizes[task]["w"])

    def __call__(self, data):
        """
        Resize a dict of tensors to the "x" key's new_size.

        Args:
            data (dict[str:torch.Tensor]): The data dict to transform

        Returns:
            dict[str: torch.Tensor]: dict with all tensors resized to the
                new size of the data["x"] tensor
        """
        task = tensor = new_size = None
        try:
            if not self.sizes:
                # single target size: all tensors share the size computed
                # from "x" (or the first tensor if "x" is absent)
                d = {}
                new_size = self.compute_new_default_size(
                    data["x"] if "x" in data else list(data.values())[0]
                )
                for task, tensor in data.items():
                    d[task] = F.interpolate(
                        tensor, size=new_size, **interpolation(task)
                    )
                return d

            # per-task target sizes
            d = {}
            for task, tensor in data.items():
                new_size = self.compute_new_size_for_task(task)
                d[task] = F.interpolate(tensor, size=new_size, **interpolation(task))
            return d

        except Exception:
            tb = traceback.format_exc()
            print("Debug: task, shape, interpolation, h, w, new_size")
            print(task)
            print(tensor.shape)
            print(interpolation(task))
            print(self.h, self.w)
            print(new_size)
            print(tb)
            # fix: re-raise the original exception unchanged instead of
            # wrapping it in a bare Exception (which lost the original
            # type and traceback)
            raise
148
+
149
+
150
class RandomCrop:
    """
    Crop a fixed (h, w) window out of every tensor of a data dict, either at
    a random location (shared across tasks) or centered.
    """

    def __init__(self, size, center=False):
        """
        Args:
            size (int | tuple | list): crop size; an int means a square crop
            center (bool, optional): center crop instead of a random one.
                Defaults to False.
        """
        assert isinstance(size, (int, tuple, list))
        if isinstance(size, int):
            self.h = self.w = size
        else:
            assert len(size) == 2
            self.h, self.w = size

        self.h = int(self.h)
        self.w = int(self.w)
        self.center = center

    def __call__(self, data):
        """
        Crop every tensor of `data` with the same (top, left) offset.

        Args:
            data (dict[str, torch.Tensor]): 4D tensors N x C x H x W

        Returns:
            dict[str, torch.Tensor]: cropped tensors
        """
        H, W = (
            data["x"].size()[-2:] if "x" in data else list(data.values())[0].size()[-2:]
        )

        if not self.center:
            # bug fix: +1 makes a crop exactly the size of the image valid;
            # np.random.randint raises ValueError on an empty range (0, 0)
            top = np.random.randint(0, H - self.h + 1)
            left = np.random.randint(0, W - self.w + 1)
        else:
            top = (H - self.h) // 2
            left = (W - self.w) // 2

        return {
            task: tensor[:, :, top : top + self.h, left : left + self.w]
            for task, tensor in data.items()
        }
179
+
180
+
181
class RandomHorizontalFlip:
    """Flip all tensors of a data dict left-right with probability p."""

    def __init__(self, p=0.5):
        """
        Args:
            p (float, optional): flip probability. Defaults to 0.5.
        """
        self.p = p

    def __call__(self, data):
        # draw once so that every task is flipped (or not) consistently
        if np.random.rand() > self.p:
            return data
        flipped = {}
        for task, tensor in data.items():
            flipped[task] = torch.flip(tensor, [3])
        return flipped
190
+
191
+
192
class ToTensor:
    """
    Convert a dict of images/arrays to a dict of tensors, with a
    task-dependent conversion:
    - "x" / "a": image -> float tensor via torchvision's ToTensor
    - "m": mask -> tensor via the same transform
    - "s": segmentation map -> int64 tensor of class indices
    - "d": depth, passed through unchanged
    """

    def __init__(self):
        self.ImagetoTensor = trsfs.ToTensor()
        self.MaptoTensor = self.ImagetoTensor

    def __call__(self, data):
        new_data = {}
        for task, im in data.items():
            if task in {"x", "a"}:
                new_data[task] = self.ImagetoTensor(im)
            elif task in {"m"}:
                new_data[task] = self.MaptoTensor(im)
            elif task == "s":
                new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(
                    torch.int64
                )
            elif task == "d":
                # bug fix: was `new_data = im`, which replaced the whole
                # output dict with the depth tensor instead of storing it
                new_data[task] = im

        return new_data
212
+
213
+
214
class Normalize:
    """
    Normalize the "x" entry of a data dict; all other tasks pass through.
    """

    def __init__(self, opts):
        """
        Args:
            opts (addict.Dict): options; opts.data.normalization selects the
                image statistics ("HRNet" -> ImageNet stats, else [-1, 1])
        """
        if opts.data.normalization == "HRNet":
            # bug fix: mean and std must be two separate arguments; they
            # were previously passed as a single nested tuple, which makes
            # torchvision's Normalize fail (missing `std`)
            self.normImage = trsfs.Normalize(
                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            )
        else:
            self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # depth, mask and segmentation maps are left untouched
        self.normDepth = lambda x: x
        self.normMask = lambda x: x
        self.normSeg = lambda x: x

        self.normalize = {
            "x": self.normImage,
            "s": self.normSeg,
            "d": self.normDepth,
            "m": self.normMask,
        }

    def __call__(self, data):
        return {
            task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0))
            for task, tensor in data.items()
        }
238
+
239
+
240
class RandBrightness:
    """Random brightness jitter on the "x" entry (input expected in [-1, 1])."""

    def __call__(self, data):
        out = {}
        for task, tensor in data.items():
            out[task] = rand_brightness(tensor) if task == "x" else tensor
        return out
246
+
247
+
248
class RandSaturation:
    """Random saturation jitter on the "x" entry; other tasks pass through."""

    def __call__(self, data):
        out = {}
        for task, tensor in data.items():
            out[task] = rand_saturation(tensor) if task == "x" else tensor
        return out
254
+
255
+
256
class RandContrast:
    """Random contrast jitter on the "x" entry; other tasks pass through."""

    def __call__(self, data):
        out = {}
        for task, tensor in data.items():
            out[task] = rand_contrast(tensor) if task == "x" else tensor
        return out
262
+
263
+
264
class BucketizeDepth:
    """
    Discretize depth maps ("d") into bucket indices for the simulated ("s")
    and kitti domains when depth classification is enabled; identity
    otherwise.
    """

    def __init__(self, opts, domain):
        self.domain = domain

        if opts.gen.d.classify.enable and domain in {"s", "kitti"}:
            linspace = opts.gen.d.classify.linspace
            # bucket boundaries in log-depth space
            self.buckets = torch.linspace(
                linspace.min, linspace.max, linspace.buckets - 1
            )
            self.transforms = {
                "d": lambda tensor: torch.bucketize(
                    tensor, self.buckets, out_int32=True, right=True
                )
            }
        else:
            self.transforms = {}

    def __call__(self, data):
        out = {}
        for task, tensor in data.items():
            transform = self.transforms.get(task)
            out[task] = transform(tensor) if transform is not None else tensor
        return out
290
+
291
+
292
class PrepareInference:
    """
    Inference pre-processing pipeline:
    - loads a str/Path into an array, converts arrays to 4D tensors
    - (images only) casts to float32, maps values to [0, 1] then to [-1, 1]
    - resizes keeping the aspect ratio, then center-crops to a square
    """

    def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True):
        """
        Args:
            target_size (int, optional): output side length. Defaults to 640.
            half (bool, optional): output float16 images. Defaults to False.
            is_label (bool, optional): input is a label map -> skip value
                normalization/rescaling. Defaults to False.
            enforce_128 (bool, optional): require target_size to be a
                multiple of 128. Defaults to True.
        """
        if enforce_128 and target_size % 2 ** 7 != 0:
            raise ValueError(
                f"Received a target_size of {target_size}, which is not a "
                + "multiple of 2^7 = 128. Set enforce_128 to False to disable "
                + "this error."
            )
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half
        self.is_label = is_label

    def process(self, t):
        """Prepare one image/label (path, array or tensor) as a 4D tensor."""
        if isinstance(t, (str, Path)):
            t = imread(str(t))

        if isinstance(t, np.ndarray):
            if t.shape[-1] == 4:
                t = rgba2rgb(t)  # drop the alpha channel
            t = torch.from_numpy(t)
            if t.ndim == 3:
                t = t.permute(2, 0, 1)  # HWC -> CHW

        # promote to a batch of 1
        if t.ndim == 2:
            t = t.unsqueeze(0).unsqueeze(0)
        elif t.ndim == 3:
            t = t.unsqueeze(0)

        if not self.is_label:
            t = t.to(torch.float32)
            t = normalize(t)
            t = (t - 0.5) * 2  # [0, 1] -> [-1, 1]

        key = "m" if self.is_label else "x"
        t = self.crop(self.resize({key: t}))[key]

        if self.half and not self.is_label:
            t = t.half()

        return t

    def __call__(self, x):
        """
        normalize, rescale, resize, crop in the center

        x can be: dict {"task": data} list [data, ..] or data
        data ^ can be a str, a Path, a numpy arrray or a Tensor
        """
        if isinstance(x, dict):
            return {k: self.process(v) for k, v in x.items()}
        if isinstance(x, list):
            return [self.process(t) for t in x]
        return self.process(x)
361
+
362
+
363
class PrepareTest:
    """
    Test-time pre-processing pipeline:
    - loads a str/Path into an array, converts arrays to 4D tensors
    - optionally normalizes to [0, 1] and rescales to [-1, 1]
    - resizes keeping the aspect ratio, then center-crops to a square
    """

    def __init__(self, target_size=640, half=False):
        """
        Args:
            target_size (int, optional): output side length. Defaults to 640.
            half (bool, optional): output float16 images. Defaults to False.
        """
        self.resize = Resize(target_size, keep_aspect_ratio=True)
        self.crop = RandomCrop((target_size, target_size), center=True)
        self.half = half

    def process(self, t, normalize=False, rescale=False):
        """Prepare one image (path, array or tensor) as a 4D float tensor."""
        if isinstance(t, (str, Path)):
            t = imread(str(t))
            if t.shape[-1] == 4:
                t = t[:, :, :3]  # drop the alpha channel
            if np.ndim(t) == 2:
                t = np.repeat(t[:, :, np.newaxis], 3, axis=2)  # gray -> RGB

        if isinstance(t, np.ndarray):
            t = torch.from_numpy(t)
            t = t.permute(2, 0, 1)  # HWC -> CHW

        if len(t.shape) == 3:
            t = t.unsqueeze(0)

        t = t.to(torch.float32)
        if normalize:
            # bug fix: the result of this call used to be discarded; also the
            # `normalize` bool argument shadows the module-level normalize()
            # function, so it must be fetched explicitly from globals
            t = globals()["normalize"](t)
        if rescale:
            # bug fix: the rescaled value used to be discarded
            t = (t - 0.5) * 2
        t = {"x": t}
        t = self.resize(t)
        t = self.crop(t)
        t = t["x"]

        if self.half:
            return t.to(torch.float16)

        return t

    def __call__(self, x, normalize=False, rescale=False):
        """
        Call process()

        x can be: dict {"task": data} list [data, ..] or data
        data ^ can be a str, a Path, a numpy arrray or a Tensor
        """
        if isinstance(x, dict):
            return {k: self.process(v, normalize, rescale) for k, v in x.items()}

        if isinstance(x, list):
            return [self.process(t, normalize, rescale) for t in x]

        return self.process(x, normalize, rescale)
422
+
423
+
424
def get_transform(transform_item, mode):
    """Returns the torchivion transform function associated to a
    transform_item listed in opts.data.transforms ; transform_item is
    an addict.Dict
    """
    # an item can be disabled globally (ignore is True) or per-mode
    if transform_item.ignore is True or transform_item.ignore == mode:
        return None

    name = transform_item.name
    if name == "crop":
        return RandomCrop(
            (transform_item.height, transform_item.width),
            center=transform_item.center == mode,
        )
    if name == "resize":
        return Resize(
            transform_item.new_size, transform_item.get("keep_aspect_ratio", False)
        )
    if name == "hflip":
        return RandomHorizontalFlip(p=transform_item.p or 0.5)
    if name == "brightness":
        return RandBrightness()
    if name == "saturation":
        return RandSaturation()
    if name == "contrast":
        return RandContrast()

    raise ValueError("Unknown transform_item {}".format(transform_item))
469
+
470
+
471
def get_transforms(opts, mode, domain):
    """Get all the transform functions listed in opts.data.transforms
    using get_transform(transform_item, mode)
    """
    color_jittering = {"brightness", "saturation", "contrast"}

    transforms = [
        get_transform(t, mode)
        for t in opts.data.transforms
        if t.name not in color_jittering
    ]

    # color jittering only applies at train time, and only when no painter
    # ("p") task is trained
    if "p" not in opts.tasks and mode == "train":
        transforms += [
            get_transform(t, mode)
            for t in opts.data.transforms
            if t.name in color_jittering
        ]

    transforms += [Normalize(opts), BucketizeDepth(opts, domain)]
    return [t for t in transforms if t is not None]
491
+
492
+
493
+ # ----- Adapted functions from https://github.com/mit-han-lab/data-efficient-gans -----#
494
def rand_brightness(tensor, is_diff_augment=False):
    """
    Randomly shift image brightness.

    Diff-augment mode adds a per-sample uniform offset in [-0.5, 0.5].
    Otherwise torchvision's adjust_brightness is applied with a factor in
    [0.5, 1.5], and two corner pixels are pinned to 1.0 / 0.0 so later
    rescaling preserves the value range.
    """
    if not is_diff_augment:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_brightness(tensor, brightness_factor=factor)
        # dummy pixels to fool scaling and preserve range
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor

    assert len(tensor.shape) == 4
    offset = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    return tensor + (offset - 0.5)
508
+
509
+
510
def rand_saturation(tensor, is_diff_augment=False):
    """
    Randomly alter image saturation.

    Diff-augment mode interpolates each sample between its per-pixel channel
    mean (grayscale) and the original with a random gain in [0, 2].
    Otherwise torchvision's adjust_saturation is applied with a factor in
    [0.5, 1.5] and two corner pixels are pinned to preserve the range.
    """
    if not is_diff_augment:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_saturation(tensor, saturation_factor=factor)
        # dummy pixels to fool scaling and preserve range
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor

    assert len(tensor.shape) == 4
    gain = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    channel_mean = tensor.mean(dim=1, keepdim=True)
    return (tensor - channel_mean) * (gain * 2) + channel_mean
525
+
526
+
527
def rand_contrast(tensor, is_diff_augment=False):
    """
    Randomly alter image contrast.

    Diff-augment mode scales each sample around its global mean with a
    random gain in [0.5, 1.5]. Otherwise torchvision's adjust_contrast is
    applied with a factor in [0.5, 1.5] and two corner pixels are pinned to
    preserve the range.
    """
    if not is_diff_augment:
        factor = random.uniform(0.5, 1.5)
        tensor = adjust_contrast(tensor, contrast_factor=factor)
        # dummy pixels to fool scaling and preserve range
        tensor[:, :, 0, 0] = 1.0
        tensor[:, :, -1, -1] = 0.0
        return tensor

    assert len(tensor.shape) == 4
    gain = torch.rand(
        tensor.size(0), 1, 1, 1, dtype=tensor.dtype, device=tensor.device
    )
    sample_mean = tensor.mean(dim=[1, 2, 3], keepdim=True)
    return (tensor - sample_mean) * (gain + 0.5) + sample_mean
542
+
543
+
544
def rand_cutout(tensor, ratio=0.5):
    """
    Zero out one random rectangle (each side ~ratio of the image side) per
    sample, shared across channels. Adapted from
    https://github.com/mit-han-lab/data-efficient-gans.
    """
    assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D."
    dtype_ = tensor.dtype
    device_ = tensor.device
    n, h, w = tensor.size(0), tensor.size(-2), tensor.size(-1)
    cut_h = int(h * ratio + 0.5)
    cut_w = int(w * ratio + 0.5)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(n, dtype=torch.long, device=device_),
        torch.arange(cut_h, dtype=torch.long, device=device_),
        torch.arange(cut_w, dtype=torch.long, device=device_),
    )
    # random per-sample center for the cutout window
    offset_x = torch.randint(
        0, h + (1 - cut_h % 2), size=[n, 1, 1], device=device_
    )
    offset_y = torch.randint(
        0, w + (1 - cut_w % 2), size=[n, 1, 1], device=device_
    )
    # clamp so the window stays inside the image
    grid_x = torch.clamp(grid_x + offset_x - cut_h // 2, min=0, max=h - 1)
    grid_y = torch.clamp(grid_y + offset_y - cut_w // 2, min=0, max=w - 1)
    mask = torch.ones(n, tensor.size(2), tensor.size(3), dtype=dtype_, device=device_)
    mask[grid_batch, grid_x, grid_y] = 0
    return tensor * mask.unsqueeze(1)
578
+
579
+
580
def rand_translation(tensor, ratio=0.125):
    """
    Randomly translate each sample by up to ratio of its height/width,
    filling with zeros. Adapted from
    https://github.com/mit-han-lab/data-efficient-gans.
    """
    assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D."
    device_ = tensor.device
    n, h, w = tensor.size(0), tensor.size(2), tensor.size(3)
    max_dx = int(h * ratio + 0.5)
    max_dy = int(w * ratio + 0.5)
    translation_x = torch.randint(
        -max_dx, max_dx + 1, size=[n, 1, 1], device=device_
    )
    translation_y = torch.randint(
        -max_dy, max_dy + 1, size=[n, 1, 1], device=device_
    )
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(n, dtype=torch.long, device=device_),
        torch.arange(h, dtype=torch.long, device=device_),
        torch.arange(w, dtype=torch.long, device=device_),
    )
    # the +1 offset indexes into the zero-padded tensor below
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, h + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, w + 1)
    padded = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0])
    return (
        padded.permute(0, 2, 3, 1)
        .contiguous()[grid_batch, grid_x, grid_y]
        .permute(0, 3, 1, 2)
    )
607
+
608
+
609
class DiffTransforms:
    """Differentiable augmentations (DiffAugment) for discriminator inputs."""

    def __init__(self, diff_aug_opts):
        self.do_color_jittering = diff_aug_opts.do_color_jittering
        self.do_cutout = diff_aug_opts.do_cutout
        self.do_translation = diff_aug_opts.do_translation
        self.cutout_ratio = diff_aug_opts.cutout_ratio
        self.translation_ratio = diff_aug_opts.translation_ratio

    def __call__(self, tensor):
        # order matters: jitter, then translate, then cut out
        if self.do_color_jittering:
            for jitter in (rand_brightness, rand_contrast, rand_saturation):
                tensor = jitter(tensor, is_diff_augment=True)
        if self.do_translation:
            tensor = rand_translation(tensor, ratio=self.translation_ratio)
        if self.do_cutout:
            tensor = rand_cutout(tensor, ratio=self.cutout_ratio)
        return tensor
climategan/tutils.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tensor-utils
2
+ """
3
+ import io
4
+ import math
5
+ from contextlib import redirect_stdout
6
+ from pathlib import Path
7
+
8
+ # from copy import copy
9
+ from threading import Thread
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from skimage import io as skio
15
+ from torch import autograd
16
+ from torch.autograd import Variable
17
+ from torch.nn import init
18
+
19
+ from climategan.utils import all_texts_to_array
20
+
21
+
22
def transforms_string(ts):
    """Join the class names of a Compose's transforms with ' -> '."""
    names = (t.__class__.__name__ for t in ts.transforms)
    return " -> ".join(names)
24
+
25
+
26
def init_weights(net, init_type="normal", init_gain=0.02, verbose=0, caller=""):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | xavier_uniform | kaiming |
                             orthogonal | none (pytorch's default init)
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        verbose (int)     -- print the chosen method when > 0
        caller (str)      -- identifies the call site in fallback messages

    We use 'normal' in the original pix2pix and CycleGAN paper.
    But xavier and kaiming might work better for some applications.
    Feel free to try yourself.
    """

    if not init_type:
        print(
            "init_weights({}): init_type is {}, defaulting to normal".format(
                caller + " " + net.__class__.__name__, init_type
            )
        )
        init_type = "normal"
    if not init_gain:
        # bug fix: this message used to format init_type instead of init_gain
        # and claimed the default was "normal" instead of 0.02
        print(
            "init_weights({}): init_gain is {}, defaulting to 0.02".format(
                caller + " " + net.__class__.__name__, init_gain
            )
        )
        init_gain = 0.02

    def init_func(m):
        classname = m.__class__.__name__
        if classname.find("BatchNorm2d") != -1:
            # BatchNorm weights are drawn around 1, not 0
            if hasattr(m, "weight") and m.weight is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif hasattr(m, "weight") and (
            classname.find("Conv") != -1 or classname.find("Linear") != -1
        ):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "xavier_uniform":
                init.xavier_uniform_(m.weight.data, gain=1.0)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            elif init_type == "none":  # uses pytorch's default init method
                m.reset_parameters()
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    if verbose > 0:
        print("initialize %s with %s" % (net.__class__.__name__, init_type))
    net.apply(init_func)
86
+
87
+
88
def domains_to_class_tensor(domains, one_hot=False):
    """Convert a list of domain strings to class labels.

    >>> domains_to_class_tensor(["s", "r"])
    tensor([1, 0])

    Args:
        domains (list(str)): each element must be "r" (real) or "s" (sim)
        one_hot (bool, optional): return one-hot encoded labels (2D float
            tensor) instead of integer indices. Defaults to False.

    Raises:
        ValueError: a listed domain is neither "r" nor "s"

    Returns:
        torch.Tensor: 1D int labels, or 2D one-hot labels
    """
    mapping = {"r": 0, "s": 1}

    unknown = [d for d in domains if d not in mapping]
    if unknown:
        raise ValueError(
            "Unknown domains {} should be in {}".format(domains, list(mapping.keys()))
        )

    target = torch.tensor([mapping[d] for d in domains])

    if not one_hot:
        return target

    # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
    encoded = torch.FloatTensor(len(target), 2)  # 2 domains
    encoded.zero_()
    encoded.scatter_(1, target.unsqueeze(1), 1)
    return encoded
122
+
123
+
124
def fake_domains_to_class_tensor(domains, one_hot=False):
    """Convert a list of domain strings to *flipped* class labels, used as
    adversarial targets for the domain classifier.

    >>> fake_domains_to_class_tensor(["r", "s"])
    tensor([1, 0])

    Args:
        domains (list(str)): each element must be "r" or "s"
        one_hot (bool, optional): return a 2D tensor filled with 0.5
            (equiprobability for each domain, to fool the classifier)
            instead of flipped indices. Defaults to False.

    Raises:
        ValueError: a listed domain is neither "r" nor "s"

    Returns:
        torch.Tensor: 1D flipped int labels, or a 2D tensor of 0.5s
    """
    if one_hot:
        target = torch.FloatTensor(len(domains), 2)
        target.fill_(0.5)
        return target

    mapping = {"r": 1, "s": 0}
    unknown = [d for d in domains if d not in mapping]
    if unknown:
        raise ValueError(
            "Unknown domains {} should be in {}".format(
                domains, list(mapping.keys())
            )
        )

    return torch.tensor([mapping[d] for d in domains])
160
+
161
+
162
def show_tanh_tensor(tensor):
    """Display an image tensor/array, rescaling [-1, 1] data to [0, 1]."""
    import skimage

    if isinstance(tensor, torch.Tensor):
        image = tensor.permute(1, 2, 0).detach().numpy()
    else:
        image = tensor
        if image.shape[-1] != 3:
            image = image.transpose(1, 2, 0)  # CHW -> HWC

    if -1 < image.min() < 0:
        image = image / 2 + 0.5  # [-1, 1] -> [0, 1]
    elif image.min() < -1:
        raise ValueError("can't handle this data")

    skimage.io.imshow(image)
178
+
179
+
180
def normalize_tensor(t):
    """
    Bring any tensor to the [0; 1] range (min -> 0, max -> 1).

    Args:
        t (torch.Tensor): input to normalize

    Returns:
        torch.Tensor: t projected to [0; 1]
    """
    shifted = t - torch.min(t)
    return shifted / torch.max(shifted)
193
+
194
+
195
def get_normalized_depth_t(tensor, domain, normalize=False, log=True):
    """
    Bring a raw depth map to the representation expected by the trainer,
    depending on the domain it comes from.

    Args:
        tensor (torch.Tensor): raw depth map
        domain (str): "r" (MegaDepth), "s" (Unity simulator) or "kitti"
        normalize (bool, optional): min-max normalize. Defaults to False.
        log (bool, optional): use log-depth instead of inverse depth.
            Defaults to True. Mutually exclusive with normalize.

    Returns:
        torch.Tensor: processed depth with a leading channel dimension
    """
    assert not (normalize and log)
    if domain == "r":
        # megadepth depth: min-max normalize to [0, 1]
        tensor = tensor.unsqueeze(0)
        tensor = tensor - torch.min(tensor)
        tensor = torch.true_divide(tensor, torch.max(tensor))

    elif domain == "s":
        # 3-channel depth encoding from the Unity simulator -> 1 channel
        tensor = decode_unity_depth_t(tensor, log=log, normalize=normalize)

    elif domain == "kitti":
        tensor = tensor / 100
        if log:
            tensor = torch.log(tensor)
        else:
            tensor = 1 / tensor
            if normalize:
                tensor = tensor - tensor.min()
                tensor = tensor / tensor.max()
        tensor = tensor.unsqueeze(0)

    return tensor
220
+
221
+
222
def decode_bucketed_depth(tensor, opts):
    """
    Turn a 1 x C x H x W tensor of per-bucket scores into a metric depth map
    by selecting each pixel's argmax bucket and exponentiating its log-depth
    center value.

    Returns a 1 x 1 x H x W float tensor on the input's device.
    """
    # tensor is size 1 x C x H x W
    assert tensor.shape[0] == 1
    # channels become dim 0 after the squeeze
    bucket_idx = torch.argmax(tensor.squeeze(0), dim=0)
    centers = torch.linspace(
        opts.gen.d.classify.linspace.min,
        opts.gen.d.classify.linspace.max,
        opts.gen.d.classify.linspace.buckets,
    )
    log_depth = centers[bucket_idx.long()].to(torch.float32)  # H x W
    return torch.exp(log_depth).unsqueeze(0).unsqueeze(0).to(tensor.device)
235
+
236
+
237
def decode_unity_depth_t(unity_depth, log=True, normalize=False, numpy=False, far=1000):
    """Transform the 3-channel encoded depth map from our Unity simulator
    into a 1-channel depth map.

    The simulator stores (1 - LinearDepth) sliced hierarchically into the
    RGB channels: 31 coarse slices in R, 31 sub-slices in G and 256 fine
    slices in B, i.e. N = 31*31*256 - 1 discrete levels over `far` meters,
    with the furthest point encoded as [0, 0, 0] (sky) and the closest as
    [255, 255, 255]. The metric distance of a pixel (R, G, B) is:

        d = (far/N) * [((255 - R)//8)*256*31 + ((255 - G)//8)*256 + (255 - B)]

    Args:
        unity_depth (torch.Tensor): H x W x 3 depth map from the simulator
        log (bool, optional): return log-depth. Defaults to True.
        normalize (bool, optional): min-max normalize. Defaults to False.
        numpy (bool, optional): return a squeezed np.uint8 array instead of
            a float tensor. Defaults to False.
        far (int, optional): far plane of the simulator camera.

    Returns:
        torch.Tensor | numpy.array: decoded depth
    """
    R = unity_depth[:, :, 0]
    G = unity_depth[:, :, 1]
    B = unity_depth[:, :, 2]

    R = ((247 - R) / 8).type(torch.IntTensor)
    G = ((247 - G) / 8).type(torch.IntTensor)
    B = (255 - B).type(torch.IntTensor)
    n_levels = 256 * 31 * 31 - 1
    depth = (R * 256 * 31 + G * 256 + B).type(torch.FloatTensor) / n_levels
    depth = depth * far

    if log:
        depth = torch.log(depth.unsqueeze(0))
    else:
        depth = (1 / depth).unsqueeze(0)

    if normalize:
        depth = depth - torch.min(depth)
        depth /= torch.max(depth)
    if numpy:
        return depth.data.cpu().numpy().astype(np.uint8).squeeze()
    return depth
294
+
295
+
296
def to_inv_depth(log_depth, numpy=False):
    """Convert a log-depth tensor to a max-normalized inverse-depth image
    for display.

    Args:
        log_depth (Tensor): log depth float tensor
        numpy (bool, optional): return a numpy array. Defaults to False.
    """
    # inverse depth avoids needing a sky segmentation for visualization:
    # the network's arbitrary large sky depths map to ~0 instead of
    # dominating the displayed range
    inv_depth = 1 / torch.exp(log_depth)
    inv_depth /= torch.max(inv_depth)
    if numpy:
        inv_depth = inv_depth.data.cpu().numpy()
        # you might also use percentile for better visualization

    return inv_depth
314
+
315
+
316
def shuffle_batch_tuple(mbt):
    """Shuffle the order of domain-specific batches in a multi-batch tuple.

    Args:
        mbt (tuple | list): multi-batch tuple

    Returns:
        list: the same batches, in a random order
    """
    assert isinstance(mbt, (tuple, list))
    assert len(mbt) > 0
    return [mbt[i] for i in np.random.permutation(len(mbt))]
329
+
330
+
331
def slice_batch(batch, slice_size):
    """Truncate every entry of a (possibly one-level-nested) batch dict to
    its first slice_size elements, in place, and return the dict."""
    assert slice_size > 0
    for key, value in batch.items():
        if isinstance(value, dict):
            for task in value:
                batch[key][task] = value[task][:slice_size]
        else:
            batch[key] = value[:slice_size]
    return batch
340
+
341
+
342
def save_tanh_tensor(image, path):
    """Save an image (numpy or tensor, CHW or HWC, 2 or 3 dims, no batch)
    to path, rescaling [-1, 1] or out-of-range data to [0, 1] first.

    Args:
        image (np.array or torch.Tensor): image to save
        path (pathlib.Path or str): where to save the image
    """
    path = Path(path)
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    if image.shape[-1] != 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
    if -1 < image.min() < 0:
        image = image / 2 + 0.5  # [-1, 1] -> [0, 1]
    elif image.min() < -1:
        # out-of-range data: min-max rescale
        image -= image.min()
        image /= image.max()

    skio.imsave(path, (image * 255).astype(np.uint8))
363
+
364
+
365
def save_batch(multi_domain_batch, root="./", step=0, num_threads=5):
    """Save [x | y] side-by-side images for every domain batch that carries
    a target, under root/im_{step}_{domain}_{id}.png. Writes are threaded
    when num_threads > 0, sequential otherwise."""
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    paths, images = [], []
    for domain, batch in multi_domain_batch.items():
        y = batch["data"].get("y")
        x = batch["data"]["x"]
        if y is None:
            continue
        x_paths = batch["paths"]["x"]
        side_by_side = torch.cat([x, y], dim=-1)
        for i, im in enumerate(side_by_side):
            imid = Path(x_paths[i]).stem[:10]
            paths.append(root / "im_{}_{}_{}.png".format(step, domain, imid))
            images.append(im)
    if num_threads > 0:
        threaded_write(images, paths, num_threads)
    else:
        for im, path in zip(images, paths):
            save_tanh_tensor(im, path)
387
+
388
def threaded_write(images, paths, num_threads=5):
    """Write images to paths with save_tanh_tensor, num_threads at a time."""

    def _flush(chunk):
        # one thread per (image, path) pair, then wait for the whole chunk
        threads = [Thread(target=save_tanh_tensor, args=pair) for pair in chunk]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    chunk = []
    for pair in zip(images, paths):
        chunk.append(pair)
        if len(chunk) == num_threads:
            _flush(chunk)
            chunk = []
    if chunk:
        _flush(chunk)
409
+
410
+
411
def get_num_params(model):
    """Total number of parameters in a torch module (trainable or not)."""
    return sum(p.numel() for p in model.parameters())
414
+
415
+
416
def vgg_preprocess(batch):
    """Adapt a [-1, 1] RGB batch to VGG's expected input: BGR channel
    order, [0, 255] value range, ImageNet channel means subtracted.

    Args:
        batch (torch.Tensor): N x 3 x H x W tensor in [-1, 1]

    Returns:
        torch.Tensor: preprocessed batch, same shape, same device
    """
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    # bug fix: the mean tensor used to be created with an unconditional
    # .cuda(), which crashed on CPU tensors; zeros_like follows the batch's
    # device and dtype (ImageNet BGR means)
    mean = torch.zeros_like(batch)
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    return batch - mean
428
+
429
+
430
def zero_grad(model: nn.Module):
    """
    Reset gradients by setting every parameter's .grad to None.

    More efficient than model.zero_grad() or opt.zero_grad() according to
    https://www.youtube.com/watch?v=9mS1fIYj1So

    Args:
        model (nn.Module): model whose gradients are cleared
    """
    for param in model.parameters():
        param.grad = None
440
+
441
+
442
# Take the prediction of fake and real images from the combined batch
def divide_pred(disc_output):
    """
    Divide a multiscale discriminator's output into 2 sets of tensors,
    expecting the input to the discriminator to be a concatenation
    on the batch axis of real and fake (or fake and real) images,
    effectively doubling the batch size for better batchnorm statistics

    Args:
        disc_output (list | torch.Tensor): Discriminator output to split

    Returns:
        list | torch.Tensor[type]: pair of split outputs
    """
    # https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py
    # the prediction contains the intermediate outputs of multiscale GAN,
    # so it's usually a list
    # use isinstance instead of `type(...) == list` (idiomatic, and also
    # accepts list subclasses)
    if isinstance(disc_output, list):
        half1 = []
        half2 = []
        for p in disc_output:
            half1.append([tensor[: tensor.size(0) // 2] for tensor in p])
            half2.append([tensor[tensor.size(0) // 2 :] for tensor in p])
    else:
        half1 = disc_output[: disc_output.size(0) // 2]
        half2 = disc_output[disc_output.size(0) // 2 :]

    return half1, half2
470
+
471
+
472
def is_tpu_available():
    """
    Return True when torch_xla is importable and reports an XLA device,
    False otherwise (in particular when torch_xla is not installed).
    """
    try:
        import torch_xla.core.xla_model as xm  # type: ignore
    except ImportError:
        return False
    return "xla" in str(xm.xla_device())
485
+
486
+
487
def get_WGAN_gradient(input, output):
    """
    Compute the WGAN-GP gradient penalty: mean squared distance of the
    per-sample gradient norms from 1.

    github code reference:
    https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py

    Fix: `grad_outputs` is built with `torch.ones_like(output)` so it lives
    on the same device/dtype as `output`, instead of the previous hard-coded
    `.cuda()` which crashed on CPU.

    Args:
        input (torch.Tensor): tensor w.r.t. which gradients are taken
        output (torch.Tensor): discriminator output on interpolated samples

    Returns:
        torch.Tensor: scalar gradient penalty
    """
    grads = autograd.grad(
        outputs=output,
        inputs=input,
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # flatten per-sample gradients before taking their L2 norm
    grads = grads.view(grads.size(0), -1)
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp
502
+
503
+
504
def print_num_parameters(trainer, force=False):
    """
    Print a table of parameter counts for the trainer's encoder, decoders,
    painter and discriminators.

    Args:
        trainer: object exposing .verbose, .G (with .encoder, .decoders,
            .painter) and .D (dict of discriminators or None)
        force (bool): print even when trainer.verbose == 0
    """
    # silent unless verbose or explicitly forced
    if trainer.verbose == 0 and not force:
        return
    print("-" * 35)
    if trainer.G.encoder is not None:
        print(
            "{:21}:".format("num params encoder"),
            f"{get_num_params(trainer.G.encoder):12,}",
        )
    # one line per task decoder
    for d in trainer.G.decoders.keys():
        print(
            "{:21}:".format(f"num params decoder {d}"),
            f"{get_num_params(trainer.G.decoders[d]):12,}",
        )

    print(
        "{:21}:".format("num params painter"),
        f"{get_num_params(trainer.G.painter):12,}",
    )

    # one line per discriminator, when any exist
    if trainer.D is not None:
        for d in trainer.D.keys():
            print(
                "{:21}:".format(f"num params discrim {d}"),
                f"{get_num_params(trainer.D[d]):12,}",
            )

    print("-" * 35)
532
+
533
+
534
def srgb2lrgb(x):
    """
    sRGB -> linear RGB: min-max normalize to [0, 1], then invert the sRGB
    gamma curve (linear segment below 0.04045, power law above).
    """
    x = normalize(x)
    gamma = ((x + 0.055) / 1.055) ** 2.4
    linear = x / 12.92
    return torch.where(x <= 0.04045, linear, gamma)
539
+
540
+
541
def lrgb2srgb(ims):
    """
    Linear RGB -> sRGB gamma encoding, applied per 3-channel image.

    Accepts a single C x H x W tensor or an N x C x H x W batch; returns
    a single tensor in the first case, a stacked batch in the second.
    """
    batched = len(ims.shape) != 3
    images = list(ims) if batched else [ims]

    converted = []
    for image in images:
        result = torch.zeros_like(image)
        for channel in range(3):
            c = image[channel, :, :]
            # piecewise sRGB curve: linear below the 0.0031308 threshold,
            # gamma 1/2.4 above
            low = 12.92 * c * (c <= 0.0031308)
            high = (1.055 * torch.pow(c, (1 / 2.4)) - 0.055) * (c > 0.0031308)
            result[channel, :, :] = low + high
        converted.append(result)

    return torch.stack(converted) if batched else converted[0]
565
+
566
+
567
def normalize(t, mini=0, maxi=1):
    """
    Min-max rescale `t` into [mini, maxi].

    3D tensors are normalized globally; 4D tensors are normalized per item
    along the batch dimension.
    """
    if len(t.shape) == 3:
        span = t.max() - t.min()
        return mini + (maxi - mini) * (t - t.min()) / span

    n = t.shape[0]
    # per-image min subtraction, then per-image max division
    t = t - t.reshape(n, -1).min(1)[0].reshape(n, 1, 1, 1)
    t = t / t.reshape(n, -1).max(1)[0].reshape(n, 1, 1, 1)
    return mini + (maxi - mini) * t
577
+
578
+
579
def retrieve_sky_mask(seg):
    """
    get the binary mask for the sky given a segmentation tensor
    of logits (N x C x H x W) or labels (N x H x W)

    Args:
        seg (torch.Tensor): Segmentation map

    Returns:
        torch.Tensor: Sky mask
    """
    # 4D input is per-class logits: reduce to per-pixel label indices first
    labels = torch.argmax(seg, dim=1) if len(seg.shape) == 4 else seg
    # class index 9 is the sky class in this segmentation taxonomy
    return labels == 9
597
+
598
+
599
def all_texts_to_tensors(texts, width=640, height=40):
    """
    Creates a list of tensors with texts from PIL images

    Args:
        texts (list(str)): texts to write
        width (int, optional): width of individual texts. Defaults to 640.
        height (int, optional): height of individual texts. Defaults to 40.

    Returns:
        list(torch.Tensor): len(texts) tensors 3 x height x width
    """
    arrays = all_texts_to_array(texts, width, height)
    # HWC numpy arrays -> CHW torch tensors
    return [torch.tensor(array.transpose(2, 0, 1)) for array in arrays]
614
+
615
+
616
def write_architecture(trainer):
    """
    Dump string representations of the trainer's networks to text files in
    trainer.opts.output_path: encoder, per-task decoders, painter (if it has
    parameters), discriminators (if any have parameters), plus a summary of
    parameter counts captured from print_num_parameters.
    """
    stem = "archi"
    out = Path(trainer.opts.output_path)

    # encoder
    with open(out / f"{stem}_encoder.txt", "w") as f:
        f.write(str(trainer.G.encoder))

    # decoders
    for k, v in trainer.G.decoders.items():
        with open(out / f"{stem}_decoder_{k}.txt", "w") as f:
            f.write(str(v))

    # painter
    if get_num_params(trainer.G.painter) > 0:
        with open(out / f"{stem}_painter.txt", "w") as f:
            f.write(str(trainer.G.painter))

    # discriminators
    if get_num_params(trainer.D) > 0:
        for k, v in trainer.D.items():
            with open(out / f"{stem}_discriminator_{k}.txt", "w") as f:
                f.write(str(v))

    # capture print_num_parameters' stdout into a string, then persist it
    with io.StringIO() as buf, redirect_stdout(buf):
        print_num_parameters(trainer)
        output = buf.getvalue()
    with open(out / "archi_num_params.txt", "w") as f:
        f.write(output)
645
+
646
+
647
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """
    Generate a 2D Perlin-noise tensor of size `shape` from a grid of
    `res` (rows, cols) random gradient vectors.

    Args:
        shape (tuple(int, int)): output height and width; assumed divisible
            by the corresponding entry of `res` (integer division is used)
        res (tuple(int, int)): number of noise cells along each axis
        fade (callable): smoothstep-like interpolation curve applied to the
            in-cell coordinates (default is Perlin's 6t^5 - 15t^4 + 10t^3)

    Returns:
        torch.Tensor: `shape[0]` x `shape[1]` noise values
    """
    # in-cell coordinate step and number of pixels per cell
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    # fractional position of every pixel inside its noise cell
    grid = (
        torch.stack(
            torch.meshgrid(
                torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
            ),
            dim=-1,
        )
        % 1
    )
    # one random unit gradient per lattice corner
    angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    # expand a sub-grid of corner gradients so each covers its whole cell
    tile_grads = (
        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        .repeat_interleave(d[0], 0)
        .repeat_interleave(d[1], 1)
    )
    # dot product between the offset-to-corner vector and the corner gradient
    dot = lambda grad, shift: (  # noqa: E731
        torch.stack(
            (
                grid[: shape[0], : shape[1], 0] + shift[0],
                grid[: shape[0], : shape[1], 1] + shift[1],
            ),
            dim=-1,
        )
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)

    # contributions from the four corners of each cell
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[: shape[0], : shape[1]])
    # bilinear interpolation of the four dot products with the faded weights
    return math.sqrt(2) * torch.lerp(
        torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
    )
687
+
688
+
689
def mix_noise(x, mask, res=(8, 3), weight=0.1):
    """
    Blend Perlin noise into `x` inside the masked region.

    Inside the mask the output is `weight * noise + (1 - weight) * x`;
    outside the mask `x` is returned unchanged.

    Args:
        x (torch.Tensor): N x 3 x H x W image tensor
        mask (torch.Tensor): N x 1 x H x W mask, repeated over 3 channels
        res (tuple): Perlin lattice resolution passed to rand_perlin_2d
        weight (float): noise blending weight inside the mask

    Returns:
        torch.Tensor: blended tensor
    """
    # shifted so the noise is non-negative before blending
    noise = rand_perlin_2d(x.shape[-2:], res).unsqueeze(0).unsqueeze(0).to(x.device)
    noise = noise - noise.min()
    # NOTE(review): the float16 cast suggests this runs under mixed
    # precision so that `mask * x` keeps x's half dtype — confirm; on a
    # float32 `x` this upcasts back and the cast is a no-op in effect
    mask = mask.repeat(1, 3, 1, 1).to(x.device).to(torch.float16)
    y = mask * (weight * noise + (1 - weight) * x) + (1 - mask) * x
    return y
695
+
696
+
697
def tensor_ims_to_np_uint8s(ims):
    """
    Transform a CHW or NCHW tensor (values in [-1, 1]) into np.uint8
    [0, 255] image arrays in HWC layout.

    Args:
        ims (torch.Tensor | list): single image tensor, batch tensor, or
            list of image tensors

    Returns:
        np.ndarray | list(np.ndarray): one array when a single image was
            given, otherwise a list of arrays
    """
    if not isinstance(ims, list):
        assert isinstance(ims, torch.Tensor)
        if ims.ndim == 3:
            ims = [ims]

    out = []
    for tensor in ims:
        if tensor.shape[0] == 3:
            tensor = tensor.permute(1, 2, 0)  # CHW -> HWC
        else:
            assert tensor.shape[-1] == 3

        array = tensor.cpu().numpy()
        array = (array + 1) / 2 * 255  # [-1, 1] -> [0, 255]
        out.append(array.astype(np.uint8))

    return out[0] if len(out) == 1 else out
climategan/utils.py ADDED
@@ -0,0 +1,1063 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """All non-tensor utils
2
+ """
3
+ import contextlib
4
+ import datetime
5
+ import json
6
+ import os
7
+ import re
8
+ import shutil
9
+ import subprocess
10
+ import time
11
+ import traceback
12
+ from os.path import expandvars
13
+ from pathlib import Path
14
+ from typing import Any, List, Optional, Union
15
+ from uuid import uuid4
16
+
17
+ import numpy as np
18
+ import torch
19
+ import yaml
20
+ from addict import Dict
21
+ from comet_ml import Experiment
22
+
23
# Default keyword arguments passed to comet_ml.Experiment: disable automatic
# metric logging, let comet parse CLI args, log GPU/CPU environment info,
# and silence the end-of-run summary
comet_kwargs = {
    "auto_metric_logging": False,
    "parse_args": True,
    "log_env_gpu": True,
    "log_env_cpu": True,
    "display_summary_level": 0,
}

# Recognized image file extensions (lower- and upper-case variants)
IMG_EXTENSIONS = set(
    [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"]
)
34
+
35
+
36
def resolve(path):
    """
    fully resolve a path:
    resolve env vars ($HOME etc.) -> expand user (~) -> make absolute

    Returns:
        pathlib.Path: resolved absolute path
    """
    expanded = expandvars(str(path))
    return Path(expanded).expanduser().resolve()
45
+
46
+
47
def copy_run_files(opts: Dict) -> None:
    """
    Copy the opts's sbatch_file and exp_file into opts.output_path, when the
    source file is set and exists and the output directory exists.

    The two files previously had duplicated copy-paste branches; they now
    share one loop with identical semantics.

    Args:
        opts (addict.Dict): options
    """
    for candidate in (opts.sbatch_file, opts.exp_file):
        if not candidate:
            continue
        p = resolve(candidate)
        if not p.exists():
            continue
        o = resolve(opts.output_path)
        if o.exists():
            shutil.copyfile(p, o / p.name)
66
+
67
+
68
def merge(
    source: Union[dict, Dict], destination: Union[dict, Dict]
) -> Union[dict, Dict]:
    """
    Recursively merge `source` into `destination` (modifying it in place)
    and return `destination`. Non-dict values in `source` overwrite those
    in `destination`; nested dicts are merged key by key.

    >>> a = {'first': {'all_rows': {'pass': 'dog', 'number': '1'}}}
    >>> b = {'first': {'all_rows': {'fail': 'cat', 'number': '5'}}}
    >>> merge(b, a) == {
    ...     'first': {'all_rows': {'pass': 'dog', 'fail': 'cat', 'number': '5'}}
    ... }
    True
    """
    for key, value in source.items():
        try:
            if isinstance(value, dict):
                # get node or create one
                node = destination.setdefault(key, {})
                merge(value, node)
            else:
                if isinstance(destination, dict):
                    destination[key] = value
                else:
                    # NOTE(review): this rebinds the local name only — the
                    # caller's non-dict `destination` is NOT modified;
                    # confirm whether this branch is ever reached in practice
                    destination = {key: value}
        except TypeError as e:
            # dump full context before re-raising to ease debugging of
            # malformed configuration structures
            print(traceback.format_exc())
            print(">>>", source)
            print(">>>", destination)
            print(">>>", key)
            print(">>>", value)
            raise Exception(e)

    return destination
106
+
107
+
108
def load_opts(
    path: Optional[Union[str, Path]] = None,
    default: Optional[Union[str, Path, dict, Dict]] = None,
    commandline_opts: Optional[Union[Dict, dict]] = None,
) -> Dict:
    """Loads a configuration Dict from up to 3 sources, then post-processes it:
    1. default file (or dict) with shared values across runs and users
    2. an overriding file with run- and user-specific values
    3. optional command-line overrides

    Args:
        path (pathlib.Path): where to find the overriding configuration
        default (pathlib.Path, optional): Where to find the default opts.
            Defaults to None. In which case it is assumed to be a default config
            which needs processing such as setting default values for lambdas and gen
            fields
        commandline_opts (dict, optional): highest-priority overrides

    Returns:
        addict.Dict: options dictionnary, with overwritten default values
    """

    # with no arguments at all, fall back to the repo's default config
    if path is None and default is None:
        path = (
            resolve(Path(__file__)).parent.parent
            / "shared"
            / "trainer"
            / "defaults.yaml"
        )

    if path:
        path = resolve(path)

    if default is None:
        default_opts = {}
    else:
        if isinstance(default, (str, Path)):
            with open(default, "r") as f:
                default_opts = yaml.safe_load(f)
        else:
            default_opts = dict(default)

    if path is None:
        overriding_opts = {}
    else:
        with open(path, "r") as f:
            overriding_opts = yaml.safe_load(f) or {}

    # precedence: commandline > overriding file > defaults
    opts = Dict(merge(overriding_opts, default_opts))

    if commandline_opts is not None and isinstance(commandline_opts, dict):
        opts = Dict(merge(commandline_opts, opts))

    # NOTE(review): this checks `kitti.pretrained` while the domain logic
    # below checks `kitti.pretrain` — confirm both keys are intended and not
    # a typo for a single flag
    if opts.train.kitti.pretrained:
        assert "kitti" in opts.data.files.train
        assert "kitti" in opts.data.files.val
        assert opts.train.kitti.epochs > 0

    # derive the set of training domains from the enabled tasks
    opts.domains = []
    if "m" in opts.tasks or "s" in opts.tasks or "d" in opts.tasks:
        opts.domains.extend(["r", "s"])
    if "p" in opts.tasks:
        opts.domains.append("rf")
    if opts.train.kitti.pretrain:
        opts.domains.append("kitti")

    opts.domains = list(set(opts.domains))

    if "s" in opts.tasks:
        if opts.gen.encoder.architecture != opts.gen.s.architecture:
            print(
                "WARNING: segmentation encoder and decoder architectures do not match"
            )
            print(
                "Encoder: {} <> Decoder: {}".format(
                    opts.gen.encoder.architecture, opts.gen.s.architecture
                )
            )
    # SPADE masking requires both depth and segmentation tasks, and is
    # incompatible with using the depth discriminator as a classifier
    if opts.gen.m.use_spade:
        if "d" not in opts.tasks or "s" not in opts.tasks:
            raise ValueError(
                "opts.gen.m.use_spade is True so tasks MUST include"
                + "both d and s, but received {}".format(opts.tasks)
            )
        if opts.gen.d.classify.enable:
            raise ValueError(
                "opts.gen.m.use_spade is True but using D as a classifier"
                + " which is a non-implemented combination"
            )

    # either depth-fusion flavor implies the DADA depth encoder
    if opts.gen.s.depth_feat_fusion is True or opts.gen.s.depth_dada_fusion is True:
        opts.gen.s.use_dada = True

    # optional shared events configuration, attached when present on disk
    events_path = (
        resolve(Path(__file__)).parent.parent / "shared" / "trainer" / "events.yaml"
    )
    if events_path.exists():
        with events_path.open("r") as f:
            events_dict = yaml.safe_load(f)
        events_dict = Dict(events_dict)
        opts.events = events_dict

    return set_data_paths(opts)
209
+
210
+
211
def set_data_paths(opts: Dict) -> Dict:
    """Update the data files paths in data.files.train and data.files.val
    from data.files.base

    Relative paths (not starting with "/") are prefixed with
    opts.data.files.base; every resulting path is asserted to exist.

    Args:
        opts (addict.Dict): options

    Returns:
        addict.Dict: updated options
    """

    for mode in ["train", "val"]:
        for domain in opts.data.files[mode]:
            # only prefix relative paths, and only when a base is configured
            if opts.data.files.base and not opts.data.files[mode][domain].startswith(
                "/"
            ):
                opts.data.files[mode][domain] = str(
                    Path(opts.data.files.base) / opts.data.files[mode][domain]
                )
            # fail fast on missing data files (note: assert is stripped
            # when Python runs with -O)
            assert Path(
                opts.data.files[mode][domain]
            ).exists(), "Cannot find {}".format(str(opts.data.files[mode][domain]))

    return opts
235
+
236
+
237
def load_test_opts(test_file_path: str = "config/trainer/local_tests.yaml") -> Dict:
    """Returns the special opts set up for local tests
    Args:
        test_file_path (str, optional): Name of the file located in config/
            Defaults to "local_tests.yaml".

    Returns:
        addict.Dict: Opts loaded from defaults.yaml and updated from test_file_path
    """
    repo_root = Path(__file__).parent.parent
    return load_opts(
        repo_root / f"{test_file_path}",
        default=repo_root / "shared/trainer/defaults.yaml",
    )
250
+
251
+
252
def get_git_revision_hash() -> str:
    """Get current git hash the code is run from

    Returns:
        str: git hash (or the exception message when git is unavailable)
    """
    try:
        raw = subprocess.check_output(["git", "rev-parse", "HEAD"])
    except Exception as e:
        return str(e)
    return raw.decode().strip()
262
+
263
+
264
def get_git_branch() -> str:
    """Get current git branch name

    Returns:
        str: git branch name (or the exception message when git is unavailable)
    """
    try:
        raw = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception as e:
        return str(e)
    return raw.decode().strip()
278
+
279
+
280
def kill_job(id: Union[int, str]) -> None:
    """Cancel the SLURM job `id` via scancel."""
    cmd = ["scancel", str(id)]
    subprocess.check_output(cmd)
282
+
283
+
284
def write_hash(path: Union[str, Path]) -> None:
    """Write the current git revision hash to the file at `path`."""
    Path(path).write_text(get_git_revision_hash())
288
+
289
+
290
def shortuid():
    """Return the first 8-hex-char group of a random uuid4 string."""
    return str(uuid4())[:8]
292
+
293
+
294
def datenowshort():
    """
    Short local timestamp, e.g. '02-25_11:35:41'.

    Returns:
        str: month-day_h:m:s
    """
    return datetime.datetime.now().strftime("%m-%d_%H:%M:%S")
306
+
307
+
308
def get_increased_path(path: Union[str, Path], use_date: bool = False) -> Path:
    """Returns a non-existing variant of `path`.

    If the resolved path does not exist it is returned as-is. Otherwise a
    `--<tag>` suffix is appended (before the extension for files) until the
    candidate does not exist; the tag is a short uuid, or a timestamp when
    `use_date` is True. Tags accumulate on repeated collisions, so names
    always grow and indexes always increase.

    The four near-identical file/dir x uid/date branches of the previous
    version are merged, and its dead commented-out code is removed.

    Args:
        path (str or pathlib.Path): the file/directory which may already exist
            and would need to be increased
        use_date (bool): tag with a timestamp instead of a short uuid

    Returns:
        pathlib.Path: increased path
    """
    fp = resolve(path)
    if not fp.exists():
        return fp

    # files keep their suffix after the tag; directories have no suffix
    is_file = fp.is_file()
    while fp.exists():
        if use_date:
            # sleep so two quick successive calls get distinct timestamps
            time.sleep(0.5)
            tag = datenowshort()
        else:
            tag = shortuid()
        if is_file:
            fp = fp.parent / f"{fp.stem}--{tag}{fp.suffix}"
        else:
            fp = fp.parent / f"{fp.name}--{tag}"
    return fp
365
+
366
+
367
def env_to_path(path: str) -> str:
    """Expand environment variable references in a path.
    E.g. $HOME/clouds -> /home/vsch/clouds

    Uses os.path.expandvars: handles $VAR and ${VAR} anywhere in the string
    and leaves unknown variables untouched. The previous hand-rolled version
    only substituted whole "/"-separated segments containing "$", raised
    KeyError on unknown variables, and mangled segments like "a$B".

    Args:
        path (str): path potentially containing env variables

    Returns:
        str: path with environment variables expanded
    """
    return expandvars(path)
383
+
384
+
385
def flatten_opts(opts: Dict) -> dict:
    """Flattens a multi-level addict.Dict or native dictionnary into a single
    level native dict with string keys representing the keys sequence to reach
    a value in the original argument.

    d = addict.Dict()
    d.a.b.c = 2
    d.a.b.d = 3
    d.a.e = 4
    d.f = 5
    flatten_opts(d)
    >>> {
        "a.b.c": 2,
        "a.b.d": 3,
        "a.e": 4,
        "f": 5,
    }

    Args:
        opts (addict.Dict or dict): addict dictionnary to flatten

    Returns:
        dict: flattened dictionnary
    """
    values_list = []

    # the previous version used a mutable default argument (vals=[]) on the
    # inner helper; the accumulator is now always passed explicitly.
    # isinstance(v, dict) covers addict.Dict too, since it subclasses dict.
    def p(d, prefix, vals):
        for k, v in d.items():
            if isinstance(v, dict):
                p(v, prefix + k + ".", vals)
            elif isinstance(v, list):
                if v and isinstance(v[0], dict):
                    # lists of dicts are flattened with their index as a key
                    for i, m in enumerate(v):
                        p(m, prefix + k + "." + str(i) + ".", vals)
                else:
                    # scalar lists are stringified as a whole
                    vals.append((prefix + k, str(v)))
            else:
                if isinstance(v, Path):
                    v = str(v)
                vals.append((prefix + k, v))

    p(opts, "", values_list)
    return dict(values_list)
428
+
429
+
430
def get_comet_rest_api_key(
    path_to_config_file: Optional[Union[str, Path]] = None
) -> str:
    """Gets a comet.ml rest_api_key in the following order:
    * config file specified as argument
    * environment variable
    * .comet.config file in the current working diretory
    * .comet.config file in your home

    config files must have a line like `rest_api_key=<some api key>`

    Args:
        path_to_config_file (str or pathlib.Path, optional): config_file to use.
            Defaults to None.

    Raises:
        ValueError: can't find a file
        ValueError: can't find the key in a file

    Returns:
        str: your comet rest_api_key
    """
    # the environment variable wins only when no explicit file was given
    if path_to_config_file is None and "COMET_REST_API_KEY" in os.environ:
        return os.environ["COMET_REST_API_KEY"]
    if path_to_config_file is not None:
        config_path = resolve(path_to_config_file)
    else:
        config_path = Path() / ".comet.config"
        if not config_path.exists():
            config_path = Path.home() / ".comet.config"
            if not config_path.exists():
                raise ValueError("Unable to find your COMET_REST_API_KEY")
    with config_path.open("r") as f:
        for line in f:
            if "rest_api_key" in line:
                return line.strip().split("=")[-1].strip()
    raise ValueError(
        "Unable to find your COMET_REST_API_KEY in {}".format(str(config_path))
    )
467
+
468
+
469
def get_files(dirName: str) -> list:
    """Recursively list every file under dirName, depth-first, with entries
    visited in sorted name order at each level."""
    all_files = []
    for entry in sorted(os.listdir(dirName)):
        full_path = os.path.join(dirName, entry)
        if os.path.isdir(full_path):
            all_files.extend(get_files(full_path))
        else:
            all_files.append(full_path)
    return all_files
481
+
482
+
483
def make_json_file(
    tasks: List[str],
    addresses: List[str],  # for windows user, use "\\" instead of using "/"
    json_names: List[str] = ["train_jsonfile.json", "val_jsonfile.json"],
    splitter: str = "/",
    pourcentage_val: float = 0.15,
) -> None:
    """
    How to use it?
    e.g.
    make_json_file(['x','m','d'], [
        '/network/tmp1/ccai/data/munit_dataset/trainA_size_1200/',
        '/network/tmp1/ccai/data/munit_dataset/seg_trainA_size_1200/',
        '/network/tmp1/ccai/data/munit_dataset/trainA_megadepth_resized/'
    ], ["train_r.json", "val_r.json"])

    Args:
        tasks (list): the list of image type like 'x', 'm', 'd', etc.
        addresses (list): the list of the corresponding address of the
            image type mentioned in tasks
        json_names (list): names for the json files, train being first
            (e.g. : ["train_r.json", "val_r.json"])
        splitter (str, optional): The path separator for the current OS.
            Defaults to '/'.
        pourcentage_val: pourcentage of files to go in validation set
    """
    assert len(tasks) == len(addresses), "keys and addresses must have the same length!"

    files = [get_files(addresses[j]) for j in range(len(tasks))]
    n_files_val = int(pourcentage_val * len(files[0]))
    n_files_train = len(files[0]) - n_files_val
    # split by index: the previous `files[0][-n_files_val:]` returned the
    # WHOLE list when n_files_val == 0, since lst[-0:] == lst
    filenames = [files[0][:n_files_train], files[0][n_files_train:]]

    file_address_map = {
        tasks[j]: {
            ".".join(file.split(splitter)[-1].split(".")[:-1]): file
            for file in files[j]
        }
        for j in range(len(tasks))
    }
    # The tasks of the file_address_map are like 'x', 'm', 'd'...
    # The values of the file_address_map are a dictionary whose tasks are the
    # filenames without extension whose values are the path of the filename
    # e.g. file_address_map =
    # {'x': {'A': 'path/to/trainA_size_1200/A.png', ...},
    #  'm': {'A': 'path/to/seg_trainA_size_1200/A.jpg',...}
    #  'd': {'A': 'path/to/trainA_megadepth_resized/A.bmp',...}
    # ...}

    for i, json_name in enumerate(json_names):
        dicts = []
        for file in filenames[i]:
            filename = file.split(splitter)[-1]  # the filename with extension
            # the filename without extension, used as the cross-task key
            filename_ = ".".join(filename.split(".")[:-1])
            dicts.append({task: file_address_map[task][filename_] for task in tasks})
        with open(json_name, "w", encoding="utf-8") as outfile:
            json.dump(dicts, outfile, ensure_ascii=False)
546
+
547
+
548
def append_task_to_json(
    path_to_json: Union[str, Path],
    path_to_new_json: Union[str, Path],
    path_to_new_images_dir: Union[str, Path],
    new_task_name: str,
):
    """Add all files for a task to an existing json file by creating a new json file
    in the specified path.
    Assumes that the files for the new task have exactly the same names as the ones
    for the other tasks

    Args:
        path_to_json: complete path to the json file to modify
        path_to_new_json: complete path to the new json file to be created
        path_to_new_images_dir: complete path of the directory where to find the
            images for the new task
        new_task_name: name of the new task

    e.g:
    append_json(
        "/network/tmp1/ccai/data/climategan/seg/train_r.json",
        "/network/tmp1/ccai/data/climategan/seg/train_r_new.json"
        "/network/tmp1/ccai/data/munit_dataset/trainA_seg_HRNet/unity_labels",
        "s",
    )
    """
    ims_list = None
    if path_to_json:
        path_to_json = Path(path_to_json).resolve()
        with open(path_to_json, "r") as f:
            ims_list = json.load(f)

    files = get_files(path_to_new_images_dir)

    if ims_list is None:
        raise ValueError(f"Could not find the list in {path_to_json}")

    # shallow copy of every per-image task->path mapping
    new_ims_list = [None] * len(ims_list)
    for i, im_dict in enumerate(ims_list):
        new_ims_list[i] = {}
        for task, path in im_dict.items():
            new_ims_list[i][task] = path

    for i, im_dict in enumerate(ims_list):
        for task, path in im_dict.items():
            file_name = os.path.splitext(path)[0]  # removes extension
            file_name = file_name.rsplit("/", 1)[-1]  # only the file_name
            file_found = False
            # substring match against all candidate files of the new task
            for file_path in files:
                if file_name in file_path:
                    file_found = True
                    new_ims_list[i][new_task_name] = file_path
                    break
            # NOTE(review): on success this breaks out of the task loop, so
            # only the FIRST task entry of each image is used to locate the
            # new file; on failure the whole function aborts without writing
            # anything — confirm both behaviors are intended
            if file_found:
                break
            else:
                print("Error! File ", file_name, "not found in directory!")
                return

    with open(path_to_new_json, "w", encoding="utf-8") as f:
        json.dump(new_ims_list, f, ensure_ascii=False)
609
+
610
+
611
def sum_dict(dict1: Union[dict, Dict], dict2: Union[Dict, dict]) -> Union[dict, Dict]:
    """Recursively add dict2's values into dict1 (in place); returns dict1."""
    for key, value in dict2.items():
        if isinstance(value, dict):
            sum_dict(dict1[key], value)
        else:
            dict1[key] += value
    return dict1
619
+
620
+
621
def div_dict(dict1: Union[dict, Dict], div_by: float) -> dict:
    """Recursively divide dict1's leaf values by div_by (in place); returns dict1."""
    for key, value in dict1.items():
        if isinstance(value, dict):
            div_dict(value, div_by)
        else:
            dict1[key] /= div_by
    return dict1
629
+
630
+
631
def comet_id_from_url(url: str) -> Optional[str]:
    """
    Get comet exp id from its url:
    https://www.comet.ml/vict0rsch/climategan/2a1a4a96afe848218c58ac4e47c5375f
    -> 2a1a4a96afe848218c58ac4e47c5375f

    Args:
        url (str): comet exp url

    Returns:
        str | None: comet exp id, or None when it cannot be extracted
    """
    try:
        parts = [part for part in url.split("/") if part]
        return parts[-1]
    except Exception:
        return None
649
+
650
+
651
@contextlib.contextmanager
def temp_np_seed(seed: Optional[int]) -> None:
    """
    Temporarily seed numpy's global RNG, restoring the previous state on exit:
        with temp_np_seed(123):
            np.random.permutation(3)

    Args:
        seed (int): temporary numpy seed
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_state)
667
+
668
+
669
def get_display_indices(opts: Dict, domain: str, length: int) -> list:
    """
    Compute the index of images to use for comet logging:
    if opts.comet.display_indices is an int, and domain is real:
        return range(int)
    if opts.comet.display_indices is an int, and domain is sim:
        return permutation(length)[:int]
    if opts.comet.display_indices is a list:
        return list

    otherwise return []


    Args:
        opts (addict.Dict): options
        domain (str): domain for those indices
        length (int): length of dataset for the permutation

    Returns:
        list(int): The indices to display
    """
    # the rf (painter) domain must cover at least the FID image count
    if domain == "rf":
        dsize = max([opts.comet.display_size, opts.train.fid.get("n_images", 0)])
    else:
        dsize = opts.comet.display_size
    if dsize > length:
        print(
            f"Warning: dataset is smaller ({length} images) "
            + f"than required display indices ({dsize})."
            + f" Selecting {length} images."
        )

    display_indices = []
    assert isinstance(dsize, (int, list)), "Unknown display size {}".format(dsize)
    if isinstance(dsize, int):
        assert dsize >= 0, "Display size cannot be < 0"
        # fixed seed so the same images are displayed across epochs/runs
        with temp_np_seed(123):
            display_indices = list(np.random.permutation(length)[:dsize])
    elif isinstance(dsize, list):
        # an explicit list of indices is used verbatim
        display_indices = dsize

    if not display_indices:
        print("Warning: no display indices (utils.get_display_indices)")

    return display_indices
714
+
715
+
716
def get_latest_path(path: Union[str, Path]) -> Path:
    """
    Get the file/dir matching `path`'s pattern with the largest increment i,
    following the naming scheme `file (i).ext`.

    For instance with path="opts.yaml" and "opts (1).yaml", "opts (3).yaml"
    on disk, "opts (3).yaml" is returned; `path` itself is returned when no
    incremented candidate exists.

    Args:
        path (str or pathlib.Path): base pattern

    Returns:
        Path: path found
    """
    p = Path(path).resolve()
    # fix: the original globbed the parent twice with the same pattern;
    # glob once (sorted for determinism) and reuse the result
    files = sorted(p.parent.glob(f"{p.stem}*(*){p.suffix}"))
    # the increment is the last "(...)" group of each candidate's name
    indices = [int(re.findall(r"\((.*?)\)", f.name)[-1]) for f in files]
    if not indices:
        return p
    return files[np.argmax(indices)]
739
+
740
+
741
def get_existing_jobID(output_path: Path) -> str:
    """
    If the opts in output_path have a jobID, return it. Else, return None

    Args:
        output_path (pathlib.Path | str): where to look

    Returns:
        str | None: jobid
    """
    output_dir = Path(output_path)
    if not output_dir.exists():
        return

    # most recent "opts.yaml" / "opts (i).yaml" dumped in this directory
    opts_path = get_latest_path(output_dir / "opts.yaml")
    if not opts_path.exists():
        return

    with opts_path.open("r") as f:
        return yaml.safe_load(f).get("jobID", None)
766
+
767
+
768
def find_existing_training(opts: Dict) -> Optional[Path]:
    """
    Looks in all directories like output_path.parent.glob(output_path.name*)
    and compares the logged slurm job id with the current opts.jobID

    If a match is found, the training should automatically continue in the
    matching output directory

    If no match is found, this is a new job and it should have a new output path

    Args:
        opts (Dict): trainer's options

    Returns:
        Optional[Path]: a path if a matching jobID is found, None otherwise
    """
    # without a job id there is nothing to match against: start fresh
    if opts.jobID is None:
        print("WARNING: current JOBID is None")
        return

    print("---------- Current job id:", opts.jobID)

    path = Path(opts.output_path).resolve()
    parent = path.parent
    name = path.name

    try:
        # candidate output dirs share the same name prefix (e.g. "run", "run-1")
        similar_dirs = [p.resolve() for p in parent.glob(f"{name}*") if p.is_dir()]

        for sd in similar_dirs:
            candidate_jobID = get_existing_jobID(sd)
            # compare as strings: the stored id may be int or str depending on dump
            if candidate_jobID is not None and str(opts.jobID) == str(candidate_jobID):
                print(f"Found matching job id in {sd}\n")
                return sd
        print("Did not find a matching job id in \n {}\n".format(str(similar_dirs)))
    except Exception as e:
        # best-effort: failing to resume should never crash a new training
        print("ERROR: Could not resume (find_existing_training)", e)
805
+
806
+
807
+ def pprint(*args: List[Any]):
808
+ """
809
+ Prints *args within a box of "=" characters
810
+ """
811
+ txt = " ".join(map(str, args))
812
+ col = "====="
813
+ space = " "
814
+ head_size = 2
815
+ header = "\n".join(["=" * (len(txt) + 2 * (len(col) + len(space)))] * head_size)
816
+ empty = "{}{}{}{}{}".format(col, space, " " * (len(txt)), space, col)
817
+ print()
818
+ print(header)
819
+ print(empty)
820
+ print("{}{}{}{}{}".format(col, space, txt, space, col))
821
+ print(empty)
822
+ print(header)
823
+ print()
824
+
825
+
826
def get_existing_comet_id(path: str) -> Optional[str]:
    """
    Returns the id of the existing comet experiment stored in path

    Args:
        path (str): Output path where to look for the comet exp

    Returns:
        Optional[str]: comet exp's ID if any was found
    """
    # the trainer dumps the exp's url in "comet_url.txt" (possibly incremented)
    url_file = get_latest_path(Path(path) / "comet_url.txt")
    if not url_file.exists():
        return None
    return comet_id_from_url(url_file.read_text().strip())
841
+
842
+
843
def get_latest_opts(path):
    """
    get latest opts dumped in path if they look like *opts*.yaml
    and were increased as
    opts.yaml < opts (1).yaml < opts (2).yaml etc.

    Args:
        path (str or pathlib.Path): where to look for opts

    Raises:
        AssertionError: If no match for opts*.yaml is found in path

    Returns:
        addict.Dict: loaded opts
    """
    path = Path(path)
    opts = get_latest_path(path / "opts.yaml")
    assert opts.exists()
    with opts.open("r") as f:
        opts = Dict(yaml.safe_load(f))

    # also attach the shared events configuration (rain/flood/etc.) when the
    # repo layout provides it; silently skipped otherwise
    events_path = Path(__file__).parent.parent / "shared" / "trainer" / "events.yaml"
    if events_path.exists():
        with events_path.open("r") as f:
            events_dict = yaml.safe_load(f)
        events_dict = Dict(events_dict)
        opts.events = events_dict

    return opts
872
+
873
+
874
def text_to_array(text, width=640, height=40):
    """
    Creates a numpy array of shape height x width x 3 with
    text written on it using PIL

    Args:
        text (str): text to write
        width (int, optional): Width of the resulting array. Defaults to 640.
        height (int, optional): Height of the resulting array. Defaults to 40.

    Returns:
        np.ndarray: Centered text
    """
    from PIL import Image, ImageDraw, ImageFont

    img = Image.new("RGB", (width, height), (255, 255, 255))
    try:
        font = ImageFont.truetype("UnBatang.ttf", 25)
    except OSError:
        font = ImageFont.load_default()

    d = ImageDraw.Draw(img)
    # Pillow >= 10 removed ImageDraw.textsize; prefer textbbox when available
    if hasattr(d, "textbbox"):
        left, top, right, bottom = d.textbbox((0, 0), text)
        text_width, text_height = right - left, bottom - top
    else:
        text_width, text_height = d.textsize(text)
    # fix: center using the actual `height` argument (was hard-coded 40) and
    # half the text width (was the full width, shifting the text left)
    h = height // 2 - 3 * text_height // 2
    w = width // 2 - text_width // 2
    d.text((w, h), text, font=font, fill=(30, 30, 30))
    return np.array(img)
901
+
902
+
903
def all_texts_to_array(texts, width=640, height=40):
    """
    Render a list of strings into a list of text images.

    Each output array has shape height x width x 3 and arrays are returned
    in the same order as `texts`, so they can be concatenated along their
    width dimension by the caller.

    Args:
        texts (list(str)): List of texts to render
        width (int, optional): Individual text's width. Defaults to 640.
        height (int, optional): Individual text's height. Defaults to 40.

    Returns:
        list: len(texts) text arrays with dims height x width x 3
    """
    arrays = []
    for text in texts:
        arrays.append(text_to_array(text, width, height))
    return arrays
917
+
918
+
919
class Timer:
    """Context-manager timer.

    Measures the elapsed time of the `with` block using CUDA events when
    `cuda=True` (accurate for GPU work) or `time.perf_counter()` otherwise.
    On exit the elapsed seconds are appended to `store` (if provided) and
    printed when `name` is non-empty; `ignore=True` disables both.
    """

    def __init__(self, name="", store=None, precision=3, ignore=False, cuda=True):
        self.name = name
        self.store = store
        self.precision = precision
        self.ignore = ignore
        self.cuda = cuda

        if cuda:
            # CUDA events time GPU work; elapsed_time requires a synchronize
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)

    def format(self, n):
        """Format a duration with the configured decimal precision."""
        return "{:.{p}f}".format(n, p=self.precision)

    def __enter__(self):
        """Start a new timer as a context manager"""
        if self.cuda:
            self._start_event.record()
        else:
            self._start_time = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        """Stop the context manager timer"""
        if self.ignore:
            return

        if self.cuda:
            self._end_event.record()
            torch.cuda.synchronize()
            # elapsed_time returns milliseconds -> convert to seconds
            elapsed = self._start_event.elapsed_time(self._end_event) / 1000
        else:
            elapsed = time.perf_counter() - self._start_time

        if self.store is not None:
            assert isinstance(self.store, list)
            self.store.append(elapsed)
        if self.name:
            print(f"[{self.name}] Elapsed time: {self.format(elapsed)}")
960
+
961
+
962
def get_loader_output_shape_from_opts(opts):
    """
    Map each task (plus the input "x") to the (h, w) output size produced by
    the data loader, read from the last "resize" transform in opts.

    Assumes opts.data.transforms contains at least one transform named
    "resize" whose `new_size` is either an int or a dict with per-task sizes
    and a "default" fallback.
    """
    transforms = opts.data.transforms

    # find the last "resize" transform: it determines the final size
    t = None
    for t in transforms[::-1]:
        if t.name == "resize":
            break
    assert t is not None

    if isinstance(t.new_size, Dict):
        # per-task sizes with a "default" fallback; output is square (s, s)
        return {
            task: (
                t.new_size.get(task, t.new_size.default),
                t.new_size.get(task, t.new_size.default),
            )
            for task in opts.tasks + ["x"]
        }
    assert isinstance(t.new_size, int)
    new_size = (t.new_size, t.new_size)
    return {task: new_size for task in opts.tasks + ["x"]}
982
+
983
+
984
def find_target_size(opts, task):
    """
    Resolve the output size for `task` from the last data transform.

    The last transform's `new_size` is either a single int shared by all
    tasks, or a mapping task -> size with a mandatory "default" fallback.

    Args:
        opts (addict.Dict): trainer options
        task (str): task name to look up

    Returns:
        int: target size for this task
    """
    new_size = opts.data.transforms[-1].new_size
    if isinstance(new_size, int):
        return new_size
    if task in new_size:
        return new_size[task]
    assert "default" in new_size
    return new_size["default"]
996
+
997
+
998
def to_128(im, w_target=-1):
    """
    Compute the largest (height, width) multiple-of-128 dimensions that
    approximately preserve `im`'s aspect ratio.

    Args:
        im (np.ndarray): image whose first two dims are (height, width)
        w_target (int, optional): target width, rounded down to a multiple
            of 128; defaults to the image's own width when negative.

    Returns:
        tuple(int, int): (new_height, new_width), both multiples of 128
    """
    height, width = im.shape[:2]
    ratio = height / width
    target = width if w_target < 0 else w_target

    new_w = (int(target) // 128) * 128
    new_h = (int(new_w * ratio) // 128) * 128

    return new_h, new_w
1008
+
1009
+
1010
def is_image_file(filename):
    """Check that a file's name points to a known image format"""
    suffix = filename.suffix if isinstance(filename, Path) else Path(filename).suffix
    return suffix in IMG_EXTENSIONS
1016
+
1017
+
1018
def find_images(path, recursive=False):
    """
    List all image files directly inside a directory, or anywhere below it
    when `recursive` is True:

    - path.glob("*") if not recursive
    - path.glob("**/*") if recursive

    Args:
        path (str | Path): directory to search; must exist
        recursive (bool): whether to descend into sub-directories

    Returns:
        list(Path): files whose extension is a known image format
    """
    root = Path(path)
    assert root.exists()
    assert root.is_dir()

    pattern = "**/*" if recursive else "*"
    return [f for f in root.glob(pattern) if f.is_file() and is_image_file(f)]
1033
+
1034
+
1035
def cols():
    """Return the terminal width in columns, falling back to 50 when the
    terminal size cannot be queried (e.g. non-interactive runs)."""
    try:
        return os.get_terminal_size().columns
    except Exception:
        return 50
1041
+
1042
+
1043
def upload_images_to_exp(
    path, exp=None, project_name="climategan-eval", sleep=-1, verbose=0
):
    """
    Upload every image found in `path` to a comet experiment.

    Args:
        path (str | Path): directory containing the images to upload
        exp (comet_ml.Experiment, optional): experiment to log to; a new one
            is created under `project_name` when None
        project_name (str): comet project for the new experiment
        sleep (float): seconds to wait between uploads when > 0 (rate limit)
        verbose (int): 0 silent, 1 single-line progress (\\r), >1 one line
            per image

    Returns:
        comet_ml.Experiment: the experiment images were logged to
    """
    ims = find_images(path)
    end = None
    c = cols()
    if verbose == 1:
        end = "\r"
    if verbose > 1:
        end = "\n"
    if exp is None:
        exp = Experiment(project_name=project_name)
    for im in ims:
        exp.log_image(str(im))
        if verbose > 0:
            if verbose == 1:
                # clear the previous progress line before overwriting it
                print(" " * (c - 1), end="\r", flush=True)
            print(str(im), end=end, flush=True)
        if sleep > 0:
            time.sleep(sleep)
    return exp
config/model/masker/.ipynb_checkpoints/opts-checkpoint.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8fd82a0d6c1de82a4ec0c1f70d0a7d3533b603a4d1ecf1c4a93d0e48aa94c31
3
+ size 6730
config/model/masker/opts.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8fd82a0d6c1de82a4ec0c1f70d0a7d3533b603a4d1ecf1c4a93d0e48aa94c31
3
+ size 6730
config/model/painter/opts.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:743b4cb46c6c62c424e348fec0093171cece006547deaec66f5324937bab4c13
3
+ size 5329
eval_masker.py ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compute metrics of the performance of the masker using a set of ground-truth labels
3
+
4
+ run eval_masker.py --model "/miniscratch/_groups/ccai/checkpoints/model/"
5
+
6
+ """
7
+ print("Imports...", end="")
8
+ import os
9
+ import os.path
10
+ from argparse import ArgumentParser
11
+ from pathlib import Path
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import pandas as pd
16
+ from comet_ml import Experiment
17
+ import torch
18
+ import yaml
19
+ from skimage.color import rgba2rgb
20
+ from skimage.io import imread, imsave
21
+ from skimage.transform import resize
22
+ from skimage.util import img_as_ubyte
23
+ from torchvision.transforms import ToTensor
24
+
25
+ from climategan.data import encode_mask_label
26
+ from climategan.eval_metrics import (
27
+ masker_classification_metrics,
28
+ get_confusion_matrix,
29
+ edges_coherence_std_min,
30
+ boxplot_metric,
31
+ clustermap_metric,
32
+ )
33
+ from climategan.transforms import PrepareTest
34
+ from climategan.trainer import Trainer
35
+ from climategan.utils import find_images
36
+
37
+ dict_metrics = {
38
+ "names": {
39
+ "tpr": "TPR, Recall, Sensitivity",
40
+ "tnr": "TNR, Specificity, Selectivity",
41
+ "fpr": "FPR",
42
+ "fpt": "False positives relative to image size",
43
+ "fnr": "FNR, Miss rate",
44
+ "fnt": "False negatives relative to image size",
45
+ "mpr": "May positive rate (MPR)",
46
+ "mnr": "May negative rate (MNR)",
47
+ "accuracy": "Accuracy (ignoring may)",
48
+ "error": "Error (ignoring may)",
49
+ "f05": "F0.05 score",
50
+ "precision": "Precision",
51
+ "edge_coherence": "Edge coherence",
52
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
53
+ },
54
+ "threshold": {
55
+ "tpr": 0.95,
56
+ "tnr": 0.95,
57
+ "fpr": 0.05,
58
+ "fpt": 0.01,
59
+ "fnr": 0.05,
60
+ "fnt": 0.01,
61
+ "accuracy": 0.95,
62
+ "error": 0.05,
63
+ "f05": 0.95,
64
+ "precision": 0.95,
65
+ "edge_coherence": 0.02,
66
+ "accuracy_must_may": 0.5,
67
+ },
68
+ "key_metrics": ["f05", "error", "edge_coherence", "mnr"],
69
+ }
70
+
71
+ print("Ok.")
72
+
73
+
74
def parsed_args():
    """Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    # model / data locations
    parser.add_argument(
        "--model",
        type=str,
        help="Path to a pre-trained model",
    )
    parser.add_argument(
        "--images_dir",
        default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/imgs",
        type=str,
        help="Directory containing the original test images",
    )
    parser.add_argument(
        "--labels_dir",
        default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/labels",
        type=str,
        help="Directory containing the labeled images",
    )
    # pre-processing options
    parser.add_argument(
        "--image_size",
        default=640,
        type=int,
        help="The height and weight of the pre-processed images",
    )
    parser.add_argument(
        "--max_files",
        default=-1,
        type=int,
        help="Limit loaded samples",
    )
    parser.add_argument(
        "--bin_value", default=0.5, type=float, help="Mask binarization threshold"
    )
    # evaluation configuration: a yaml file may list several models to compare
    parser.add_argument(
        "-y",
        "--yaml",
        default=None,
        type=str,
        help="load a yaml file to parametrize the evaluation",
    )
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    # output / logging flags
    parser.add_argument(
        "-p",
        "--plot",
        action="store_true",
        default=False,
        help="Plot masker images & their metrics overlays",
    )
    parser.add_argument(
        "--no_paint",
        action="store_true",
        default=False,
        help="Do not log painted images",
    )
    parser.add_argument(
        "--write_metrics",
        action="store_true",
        default=False,
        help="If True, write CSV file and maps images in model's path directory",
    )
    parser.add_argument(
        "--load_metrics",
        action="store_true",
        default=False,
        help="If True, load predictions and metrics instead of re-computing",
    )
    parser.add_argument(
        "--prepare_torch",
        action="store_true",
        default=False,
        help="If True, pre-process images as torch tensors",
    )
    parser.add_argument(
        "--output_csv",
        default=None,
        type=str,
        help="Filename of the output CSV with the metrics of all models",
    )

    return parser.parse_args()
162
+
163
+
164
def uint8(array):
    """Cast an array to 8-bit unsigned integers (values are truncated)."""
    return array.astype("uint8")
166
+
167
+
168
def crop_and_resize(image_path, label_path):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimensions
    is 640, then crops this resized image in its center so that the output is 640x640
    without aspect ratio distortion

    The label is resized with the same policy but with nearest-neighbor
    interpolation (order=0) so its class values are preserved, then cropped
    identically.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label)
    """

    img = imread(image_path)
    lab = imread(label_path)

    if img.shape[:2] != lab.shape[:2]:
        print(
            "\nWARNING: shape mismatch: im -> ({}) {}, lab -> ({}) {}".format(
                img.shape[:2], image_path.name, lab.shape[:2], label_path.name
            )
        )

    # resize keeping aspect ratio: smallest dim is 640
    i_h, i_w = img.shape[:2]
    if i_h < i_w:
        i_size = (640, int(640 * i_w / i_h))
    else:
        i_size = (int(640 * i_h / i_w), 640)

    # fix: compute the label's target size from the LABEL's shape, not the
    # image's (the original read img.shape here; they may differ, see warning)
    l_h, l_w = lab.shape[:2]
    if l_h < l_w:
        l_size = (640, int(640 * l_w / l_h))
    else:
        l_size = (int(640 * l_h / l_w), 640)

    r_img = resize(img, i_size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # order=0 (nearest) so label values are not interpolated into new classes
    r_lab = resize(lab, l_size, preserve_range=True, anti_aliasing=False, order=0)
    r_lab = uint8(r_lab)

    # crop in the center
    H, W = r_img.shape[:2]

    top = (H - 640) // 2
    left = (W - 640) // 2

    rc_img = r_img[top : top + 640, left : left + 640, :]
    rc_lab = (
        r_lab[top : top + 640, left : left + 640, :]
        if r_lab.ndim == 3
        else r_lab[top : top + 640, left : left + 640]
    )

    return rc_img, rc_lab
230
+
231
+
232
def plot_images(
    output_filename,
    img,
    label,
    pred,
    metrics_dict,
    maps_dict,
    edge_coherence=-1,
    pred_edge=None,
    label_edge=None,
    dpi=300,
    alpha=0.5,
    vmin=0.0,
    vmax=1.0,
    fontsize="xx-small",
    cmap={
        "fp": "Reds",
        "fn": "Reds",
        "may_neg": "Oranges",
        "may_pos": "Purples",
        "pred": "Greens",
    },
):
    """
    Plot a 1x5 per-image overview of the masker's prediction and save it.

    Panels: false positives, false negatives, "may flood" maps (with edge
    coherence in the title when provided), the predicted mask, and the
    ground-truth labels — each overlaid on the input image.

    Args:
        output_filename (str | Path): where to save the figure
        img (np.ndarray): input RGB image (H x W x 3)
        label (np.ndarray): ground-truth label map
        pred (np.ndarray): predicted mask (H x W)
        metrics_dict (dict): per-image metrics (fpr, fnr, mnr, mpr, ...)
        maps_dict (dict): per-pixel metric maps (fp, fn, may_neg, may_pos)
        edge_coherence (float, optional): edge coherence metric; -1 to omit
        pred_edge, label_edge: unused, kept for backward compatibility
        dpi, alpha, vmin, vmax, fontsize, cmap: plotting options
    """
    f, axes = plt.subplots(1, 5, dpi=dpi)

    # FPR (predicted mask on cannot flood)
    axes[0].imshow(img)
    fp_map_plt = axes[0].imshow(  # noqa: F841
        maps_dict["fp"], vmin=vmin, vmax=vmax, cmap=cmap["fp"], alpha=alpha
    )
    axes[0].axis("off")
    axes[0].set_title("FPR: {:.4f}".format(metrics_dict["fpr"]), fontsize=fontsize)

    # FNR (missed mask on must flood)
    axes[1].imshow(img)
    fn_map_plt = axes[1].imshow(  # noqa: F841
        maps_dict["fn"], vmin=vmin, vmax=vmax, cmap=cmap["fn"], alpha=alpha
    )
    axes[1].axis("off")
    axes[1].set_title("FNR: {:.4f}".format(metrics_dict["fnr"]), fontsize=fontsize)

    # May flood
    axes[2].imshow(img)
    if edge_coherence != -1:
        title = "MNR: {:.2f} | MPR: {:.2f}\nEdge coh.: {:.4f}".format(
            metrics_dict["mnr"], metrics_dict["mpr"], edge_coherence
        )
    else:
        # fix: read mnr/mpr from metrics_dict — the original referenced
        # undefined locals `mnr` / `mpr` (NameError, silenced with noqa F821)
        title = "MNR: {:.2f} | MPR: {:.2f}".format(
            metrics_dict["mnr"], metrics_dict["mpr"]
        )
    may_neg_map_plt = axes[2].imshow(  # noqa: F841
        maps_dict["may_neg"], vmin=vmin, vmax=vmax, cmap=cmap["may_neg"], alpha=alpha
    )
    may_pos_map_plt = axes[2].imshow(  # noqa: F841
        maps_dict["may_pos"], vmin=vmin, vmax=vmax, cmap=cmap["may_pos"], alpha=alpha
    )
    axes[2].set_title(title, fontsize=fontsize)
    axes[2].axis("off")

    # Prediction
    axes[3].imshow(img)
    pred_mask = axes[3].imshow(  # noqa: F841
        pred, vmin=vmin, vmax=vmax, cmap=cmap["pred"], alpha=alpha
    )
    axes[3].set_title("Predicted mask", fontsize=fontsize)
    axes[3].axis("off")

    # Labels
    axes[4].imshow(img)
    label_mask = axes[4].imshow(label, alpha=alpha)  # noqa: F841
    axes[4].set_title("Labels", fontsize=fontsize)
    axes[4].axis("off")

    f.savefig(
        output_filename,
        dpi=f.dpi,
        bbox_inches="tight",
        facecolor="white",
        transparent=False,
    )
    plt.close(f)
320
+
321
+
322
def load_ground(ground_output_path, ref_image_path):
    """
    Load the pre-computed ("ground") mask matching a reference image.

    Looks for `<ref stem>.jpg` or `<ref stem>.png` under
    `ground_output_path/eval-metrics/pred`, crops/resizes it like the
    reference image, binarizes it and returns it as a (1, 1, H, W) float
    tensor on the GPU.

    Args:
        ground_output_path (str | Path): model output dir containing
            eval-metrics/pred
        ref_image_path (str | Path): original image the mask corresponds to

    Raises:
        ValueError: if zero or more than one matching mask file is found

    Returns:
        torch.Tensor: binary mask, shape (1, 1, H, W), float32, on cuda
    """
    gop = Path(ground_output_path)
    rip = Path(ref_image_path)

    ground_paths = list((gop / "eval-metrics" / "pred").glob(f"{rip.stem}.jpg")) + list(
        (gop / "eval-metrics" / "pred").glob(f"{rip.stem}.png")
    )
    if len(ground_paths) == 0:
        raise ValueError(
            f"Could not find a ground match in {str(gop)} for image {str(rip)}"
        )
    elif len(ground_paths) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
            + f" {list(map(str, ground_paths))}"
        )
    ground_path = ground_paths[0]
    # apply the same geometry as the reference image; keep only the mask
    _, ground = crop_and_resize(rip, ground_path)
    # collapse RGB masks to a single channel before binarizing
    if ground.ndim == 3:
        ground = ground[:, :, 0]
    ground = (ground > 0).astype(np.float32)
    return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).cuda()
344
+
345
+
346
def get_inferences(
    image_arrays, model_path, image_paths, paint=False, bin_value=0.5, verbose=0
):
    """
    Obtains the mask predictions of a model for a set of images

    Parameters
    ----------
    image_arrays : array-like
        A list of (1, CH, H, W) images

    image_paths: list(Path)
        A list of paths for images, in the same order as image_arrays

    model_path : str
        The path to a pre-trained model

    paint : bool
        Whether to also paint (flood) each image using the predicted mask

    bin_value : float
        Mask binarization threshold used before painting

    verbose : int
        If > 0, print per-image progress

    Returns
    -------
    masks : list
        A list of (H, W) predicted masks
    painted : list
        The painted images (empty when `paint` is False)
    """
    device = torch.device("cuda:0")
    torch.set_grad_enabled(False)
    to_tensor = ToTensor()

    # "ground"/"instagan" baselines have pre-computed masks on disk: only the
    # painter of a reference ClimateGAN checkpoint is used for them
    is_ground = "ground" in Path(model_path).name
    is_instagan = "instagan" in Path(model_path).name

    if is_ground or is_instagan:
        # we just care about the painter here
        ground_path = model_path
        model_path = (
            "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--38858350"
        )

    xs = [to_tensor(array).unsqueeze(0) for array in image_arrays]
    xs = [x.to(torch.float32).to(device) for x in xs]
    # normalize from [0, 1] to [-1, 1], the generator's expected input range
    xs = [(x - 0.5) * 2 for x in xs]
    trainer = Trainer.resume_from_path(
        model_path, inference=True, new_exp=None, device=device
    )
    masks = []
    painted = []
    for idx, x in enumerate(xs):
        if verbose > 0:
            print(idx, "/", len(xs), end="\r")

        if not is_ground and not is_instagan:
            m = trainer.G.mask(x=x)
        else:
            # baseline: read the pre-computed mask matching this image
            m = load_ground(ground_path, image_paths[idx])

        masks.append(m.squeeze().cpu())
        if paint:
            p = trainer.G.paint(m > bin_value, x)
            painted.append(p.squeeze().cpu())
    return masks, painted
404
+
405
+
406
+ if __name__ == "__main__":
407
+ # -----------------------------
408
+ # ----- Parse arguments -----
409
+ # -----------------------------
410
+ args = parsed_args()
411
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
412
+
413
+ # Determine output dir
414
+ try:
415
+ tmp_dir = Path(os.environ["SLURM_TMPDIR"])
416
+ except Exception as e:
417
+ print(e)
418
+ tmp_dir = Path(input("Enter tmp output directory: ")).resolve()
419
+
420
+ plot_dir = tmp_dir / "plots"
421
+ plot_dir.mkdir(parents=True, exist_ok=True)
422
+
423
+ # Build paths to data
424
+ imgs_paths = sorted(
425
+ find_images(args.images_dir, recursive=False), key=lambda x: x.name
426
+ )
427
+ labels_paths = sorted(
428
+ find_images(args.labels_dir, recursive=False),
429
+ key=lambda x: x.name.replace("_labeled.", "."),
430
+ )
431
+ if args.max_files > 0:
432
+ imgs_paths = imgs_paths[: args.max_files]
433
+ labels_paths = labels_paths[: args.max_files]
434
+
435
+ print(f"Loading {len(imgs_paths)} images and labels...")
436
+
437
+ # Pre-process images: resize + crop
438
+ # TODO: ? make cropping more flexible, not only central
439
+ if not args.prepare_torch:
440
+ ims_labs = [crop_and_resize(i, l) for i, l in zip(imgs_paths, labels_paths)]
441
+ imgs = [d[0] for d in ims_labs]
442
+ labels = [d[1] for d in ims_labs]
443
+ else:
444
+ prepare = PrepareTest()
445
+ imgs = prepare(imgs_paths, normalize=False, rescale=False)
446
+ labels = prepare(labels_paths, normalize=False, rescale=False)
447
+
448
+ imgs = [i.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for i in imgs]
449
+ labels = [
450
+ lab.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for lab in labels
451
+ ]
452
+ imgs = [rgba2rgb(img) if img.shape[-1] == 4 else img for img in imgs]
453
+ print(" Done.")
454
+
455
+ # Encode labels
456
+ print("Encode labels...", end="", flush=True)
457
+ # HW label
458
+ labels = [np.squeeze(encode_mask_label(label, "flood")) for label in labels]
459
+ print("Done.")
460
+
461
+ if args.yaml:
462
+ y_path = Path(args.yaml)
463
+ assert y_path.exists()
464
+ assert y_path.suffix in {".yaml", ".yml"}
465
+ with y_path.open("r") as f:
466
+ data = yaml.safe_load(f)
467
+ assert "models" in data
468
+
469
+ evaluations = [m for m in data["models"]]
470
+ else:
471
+ evaluations = [args.model]
472
+
473
+ for e, eval_path in enumerate(evaluations):
474
+ print("\n>>>>> Evaluation", e, ":", eval_path)
475
+ print("=" * 50)
476
+ print("=" * 50)
477
+
478
+ model_metrics_path = Path(eval_path) / "eval-metrics"
479
+ model_metrics_path.mkdir(exist_ok=True)
480
+ if args.load_metrics:
481
+ f_csv = model_metrics_path / "eval_masker.csv"
482
+ pred_out = model_metrics_path / "pred"
483
+ if f_csv.exists() and pred_out.exists():
484
+ print("Skipping model because pre-computed metrics exist")
485
+ continue
486
+
487
+ # Initialize New Comet Experiment
488
+ exp = Experiment(
489
+ project_name="climategan-masker-metrics", display_summary_level=0
490
+ )
491
+
492
+ # Obtain mask predictions
493
+ # TODO: remove (debug)
494
+ print("Obtain mask predictions", end="", flush=True)
495
+
496
+ preds, painted = get_inferences(
497
+ imgs,
498
+ eval_path,
499
+ imgs_paths,
500
+ paint=not args.no_paint,
501
+ bin_value=args.bin_value,
502
+ verbose=1,
503
+ )
504
+ preds = [pred.numpy() for pred in preds]
505
+ print(" Done.")
506
+
507
+ if args.bin_value > 0:
508
+ preds = [pred > args.bin_value for pred in preds]
509
+
510
+ # Compute metrics
511
+ df = pd.DataFrame(
512
+ columns=[
513
+ "tpr",
514
+ "tpt",
515
+ "tnr",
516
+ "tnt",
517
+ "fpr",
518
+ "fpt",
519
+ "fnr",
520
+ "fnt",
521
+ "mnr",
522
+ "mpr",
523
+ "accuracy",
524
+ "error",
525
+ "precision",
526
+ "f05",
527
+ "accuracy_must_may",
528
+ "edge_coherence",
529
+ "filename",
530
+ ]
531
+ )
532
+
533
+ print("Compute metrics and plot images")
534
+ for idx, (img, label, pred) in enumerate(zip(*(imgs, labels, preds))):
535
+ print(idx, "/", len(imgs), end="\r")
536
+
537
+ # Basic classification metrics
538
+ metrics_dict, maps_dict = masker_classification_metrics(
539
+ pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
540
+ )
541
+
542
+ # Edges coherence
543
+ edge_coherence, pred_edge, label_edge = edges_coherence_std_min(pred, label)
544
+
545
+ series_dict = {
546
+ "tpr": metrics_dict["tpr"],
547
+ "tpt": metrics_dict["tpt"],
548
+ "tnr": metrics_dict["tnr"],
549
+ "tnt": metrics_dict["tnt"],
550
+ "fpr": metrics_dict["fpr"],
551
+ "fpt": metrics_dict["fpt"],
552
+ "fnr": metrics_dict["fnr"],
553
+ "fnt": metrics_dict["fnt"],
554
+ "mnr": metrics_dict["mnr"],
555
+ "mpr": metrics_dict["mpr"],
556
+ "accuracy": metrics_dict["accuracy"],
557
+ "error": metrics_dict["error"],
558
+ "precision": metrics_dict["precision"],
559
+ "f05": metrics_dict["f05"],
560
+ "accuracy_must_may": metrics_dict["accuracy_must_may"],
561
+ "edge_coherence": edge_coherence,
562
+ "filename": str(imgs_paths[idx].name),
563
+ }
564
+ df.loc[idx] = pd.Series(series_dict)
565
+
566
+ for k, v in series_dict.items():
567
+ if k == "filename":
568
+ continue
569
+ exp.log_metric(f"img_{k}", v, step=idx)
570
+
571
+ # Confusion matrix
572
+ confmat, _ = get_confusion_matrix(
573
+ metrics_dict["tpr"],
574
+ metrics_dict["tnr"],
575
+ metrics_dict["fpr"],
576
+ metrics_dict["fnr"],
577
+ metrics_dict["mnr"],
578
+ metrics_dict["mpr"],
579
+ )
580
+ confmat = np.around(confmat, decimals=3)
581
+ exp.log_confusion_matrix(
582
+ file_name=imgs_paths[idx].name + ".json",
583
+ title=imgs_paths[idx].name,
584
+ matrix=confmat,
585
+ labels=["Cannot", "Must", "May"],
586
+ row_label="Predicted",
587
+ column_label="Ground truth",
588
+ )
589
+
590
+ if args.plot:
591
+ # Plot prediction images
592
+ fig_filename = plot_dir / imgs_paths[idx].name
593
+ plot_images(
594
+ fig_filename,
595
+ img,
596
+ label,
597
+ pred,
598
+ metrics_dict,
599
+ maps_dict,
600
+ edge_coherence,
601
+ pred_edge,
602
+ label_edge,
603
+ )
604
+ exp.log_image(fig_filename)
605
+ if not args.no_paint:
606
+ masked = img * (1 - pred[..., None])
607
+ flooded = img_as_ubyte(
608
+ (painted[idx].permute(1, 2, 0).cpu().numpy() + 1) / 2
609
+ )
610
+ combined = np.concatenate([img, masked, flooded], 1)
611
+ exp.log_image(combined, imgs_paths[idx].name)
612
+
613
+ if args.write_metrics:
614
+ pred_out = model_metrics_path / "pred"
615
+ pred_out.mkdir(exist_ok=True)
616
+ imsave(
617
+ pred_out / f"{imgs_paths[idx].stem}_pred.png",
618
+ pred.astype(np.uint8),
619
+ )
620
+ for k, v in maps_dict.items():
621
+ metric_out = model_metrics_path / k
622
+ metric_out.mkdir(exist_ok=True)
623
+ imsave(
624
+ metric_out / f"{imgs_paths[idx].stem}_{k}.png",
625
+ v.astype(np.uint8),
626
+ )
627
+
628
+ # --------------------------------
629
+ # ----- END OF IMAGES LOOP -----
630
+ # --------------------------------
631
+
632
+ if args.write_metrics:
633
+ print(f"Writing metrics in {str(model_metrics_path)}")
634
+ f_csv = model_metrics_path / "eval_masker.csv"
635
+ df.to_csv(f_csv, index_label="idx")
636
+
637
+ print(" Done.")
638
+ # Summary statistics
639
+ means = df.mean(axis=0)
640
+ confmat_mean, confmat_std = get_confusion_matrix(
641
+ df.tpr, df.tnr, df.fpr, df.fnr, df.mpr, df.mnr
642
+ )
643
+ confmat_mean = np.around(confmat_mean, decimals=3)
644
+ confmat_std = np.around(confmat_std, decimals=3)
645
+
646
+ # Log to comet
647
+ exp.log_confusion_matrix(
648
+ file_name="confusion_matrix_mean.json",
649
+ title="confusion_matrix_mean.json",
650
+ matrix=confmat_mean,
651
+ labels=["Cannot", "Must", "May"],
652
+ row_label="Predicted",
653
+ column_label="Ground truth",
654
+ )
655
+ exp.log_confusion_matrix(
656
+ file_name="confusion_matrix_std.json",
657
+ title="confusion_matrix_std.json",
658
+ matrix=confmat_std,
659
+ labels=["Cannot", "Must", "May"],
660
+ row_label="Predicted",
661
+ column_label="Ground truth",
662
+ )
663
+ exp.log_metrics(dict(means))
664
+ exp.log_table("metrics.csv", df)
665
+ exp.log_html(df.to_html(col_space="80px"))
666
+ exp.log_parameters(vars(args))
667
+ exp.log_parameter("eval_path", str(eval_path))
668
+ exp.add_tag("eval_masker")
669
+ if args.tags:
670
+ exp.add_tags(args.tags)
671
+ exp.log_parameter("model_id", Path(eval_path).name)
672
+
673
+ # Close comet
674
+ exp.end()
675
+
676
+ # --------------------------------
677
+ # ----- END OF MODElS LOOP -----
678
+ # --------------------------------
679
+
680
+ # Compare models
681
+ if (args.load_metrics or args.write_metrics) and len(evaluations) > 1:
682
+ print(
683
+ "Plots for comparing the input models will be created and logged to comet"
684
+ )
685
+
686
+ # Initialize New Comet Experiment
687
+ exp = Experiment(
688
+ project_name="climategan-masker-metrics", display_summary_level=0
689
+ )
690
+ if args.tags:
691
+ exp.add_tags(args.tags)
692
+
693
+ # Build DataFrame with all models
694
+ print("Building pandas DataFrame...")
695
+ models_df = {}
696
+ for (m, model_path) in enumerate(evaluations):
697
+ model_path = Path(model_path)
698
+ with open(model_path / "opts.yaml", "r") as f:
699
+ opt = yaml.safe_load(f)
700
+ model_feats = ", ".join(
701
+ [
702
+ t
703
+ for t in sorted(opt["comet"]["tags"])
704
+ if "branch" not in t and "ablation" not in t and "trash" not in t
705
+ ]
706
+ )
707
+ model_id = f"{model_path.parent.name[-2:]}/{model_path.name}"
708
+ df_m = pd.read_csv(
709
+ model_path / "eval-metrics" / "eval_masker.csv", index_col=False
710
+ )
711
+ df_m["model"] = [model_id] * len(df_m)
712
+ df_m["model_idx"] = [m] * len(df_m)
713
+ df_m["model_feats"] = [model_feats] * len(df_m)
714
+ models_df.update({model_id: df_m})
715
+ df = pd.concat(list(models_df.values()), ignore_index=True)
716
+ df["model_img_idx"] = df.model.astype(str) + "-" + df.idx.astype(str)
717
+ df.rename(columns={"idx": "img_idx"}, inplace=True)
718
+ dict_models_labels = {
719
+ k: f"{v['model_idx'][0]}: {v['model_feats'][0]}"
720
+ for k, v in models_df.items()
721
+ }
722
+ print("Done")
723
+
724
+ if args.output_csv:
725
+ print(f"Writing DataFrame to {args.output_csv}")
726
+ df.to_csv(args.output_csv, index_label="model_img_idx")
727
+
728
+ # Determine images with low metrics in any model
729
+ print("Constructing filter based on metrics thresholds...")
730
+ idx_not_good_in_any = []
731
+ for idx in df.img_idx.unique():
732
+ df_th = df.loc[
733
+ (
734
+ # TODO: rethink thresholds
735
+ (df.tpr <= dict_metrics["threshold"]["tpr"])
736
+ | (df.fpr >= dict_metrics["threshold"]["fpr"])
737
+ | (df.edge_coherence >= dict_metrics["threshold"]["edge_coherence"])
738
+ )
739
+ & ((df.img_idx == idx) & (df.model.isin(df.model.unique())))
740
+ ]
741
+ if len(df_th) > 0:
742
+ idx_not_good_in_any.append(idx)
743
+ filters = {"all": df.img_idx.unique(), "not_good_in_any": idx_not_good_in_any}
744
+ print("Done")
745
+
746
+ # Boxplots of metrics
747
+ print("Plotting boxplots of metrics...")
748
+ for k, f in filters.items():
749
+ print(f"\tDistribution of [{k}] images...")
750
+ for metric in dict_metrics["names"].keys():
751
+ fig_filename = plot_dir / f"boxplot_{metric}_{k}.png"
752
+ if metric in ["mnr", "mpr", "accuracy_must_may"]:
753
+ boxplot_metric(
754
+ fig_filename,
755
+ df.loc[df.img_idx.isin(f)],
756
+ metric=metric,
757
+ dict_metrics=dict_metrics["names"],
758
+ do_stripplot=True,
759
+ dict_models=dict_models_labels,
760
+ order=list(df.model.unique()),
761
+ )
762
+ else:
763
+ boxplot_metric(
764
+ fig_filename,
765
+ df.loc[df.img_idx.isin(f)],
766
+ metric=metric,
767
+ dict_metrics=dict_metrics["names"],
768
+ dict_models=dict_models_labels,
769
+ fliersize=1.0,
770
+ order=list(df.model.unique()),
771
+ )
772
+ exp.log_image(fig_filename)
773
+ print("Done")
774
+
775
+ # Cluster Maps
776
+ print("Plotting clustermaps...")
777
+ for k, f in filters.items():
778
+ print(f"\tDistribution of [{k}] images...")
779
+ for metric in dict_metrics["names"].keys():
780
+ fig_filename = plot_dir / f"clustermap_{metric}_{k}.png"
781
+ df_mf = df.loc[df.img_idx.isin(f)].pivot("img_idx", "model", metric)
782
+ clustermap_metric(
783
+ output_filename=fig_filename,
784
+ df=df_mf,
785
+ metric=metric,
786
+ dict_metrics=dict_metrics["names"],
787
+ method="average",
788
+ cluster_metric="euclidean",
789
+ dict_models=dict_models_labels,
790
+ row_cluster=False,
791
+ )
792
+ exp.log_image(fig_filename)
793
+ print("Done")
794
+
795
+ # Close comet
796
+ exp.end()
figures/ablation_comparison.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script evaluates the contribution of a technique from the ablation study for
3
+ improving the masker evaluation metrics. The differences in the metrics are computed
4
+ for all images of paired models, that is those which only differ in the inclusion or
5
+ not of the given technique. Then, statistical inference is performed through the
6
+ percentile bootstrap to obtain robust estimates of the differences in the metrics and
7
+ confidence intervals. The script plots the distribution of the bootrstraped estimates.
8
+ """
9
+ print("Imports...", end="")
10
+ from argparse import ArgumentParser
11
+ import yaml
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import os
16
+ from pathlib import Path
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib.patches as mpatches
19
+ import matplotlib.transforms as transforms
20
+
21
+
22
+ # -----------------------
23
+ # ----- Constants -----
24
+ # -----------------------
25
+
26
+ dict_models = {
27
+ "md": 11,
28
+ "dada_ms, msd, pseudo": 9,
29
+ "msd, pseudo": 4,
30
+ "dada, msd_spade, pseudo": 7,
31
+ "msd": 13,
32
+ "dada_m, msd": 17,
33
+ "dada, msd_spade": 16,
34
+ "msd_spade, pseudo": 5,
35
+ "dada_ms, msd": 18,
36
+ "dada, msd, pseudo": 6,
37
+ "ms": 12,
38
+ "dada, msd": 15,
39
+ "dada_m, msd, pseudo": 8,
40
+ "msd_spade": 14,
41
+ "m": 10,
42
+ "md, pseudo": 2,
43
+ "ms, pseudo": 3,
44
+ "m, pseudo": 1,
45
+ "ground": "G",
46
+ "instagan": "I",
47
+ }
48
+
49
+ dict_metrics = {
50
+ "names": {
51
+ "tpr": "TPR, Recall, Sensitivity",
52
+ "tnr": "TNR, Specificity, Selectivity",
53
+ "fpr": "FPR",
54
+ "fpt": "False positives relative to image size",
55
+ "fnr": "FNR, Miss rate",
56
+ "fnt": "False negatives relative to image size",
57
+ "mpr": "May positive rate (MPR)",
58
+ "mnr": "May negative rate (MNR)",
59
+ "accuracy": "Accuracy (ignoring may)",
60
+ "error": "Error",
61
+ "f05": "F05 score",
62
+ "precision": "Precision",
63
+ "edge_coherence": "Edge coherence",
64
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
65
+ },
66
+ "key_metrics": ["f05", "error", "edge_coherence"],
67
+ }
68
+ dict_techniques = {
69
+ "depth": "depth",
70
+ "segmentation": "seg",
71
+ "seg": "seg",
72
+ "dada_s": "dada_seg",
73
+ "dada_seg": "dada_seg",
74
+ "dada_segmentation": "dada_seg",
75
+ "dada_m": "dada_masker",
76
+ "dada_masker": "dada_masker",
77
+ "spade": "spade",
78
+ "pseudo": "pseudo",
79
+ "pseudo-labels": "pseudo",
80
+ "pseudo_labels": "pseudo",
81
+ }
82
+
83
+ # Markers
84
+ dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
85
+
86
+ # Model features
87
+ model_feats = [
88
+ "masker",
89
+ "seg",
90
+ "depth",
91
+ "dada_seg",
92
+ "dada_masker",
93
+ "spade",
94
+ "pseudo",
95
+ "ground",
96
+ "instagan",
97
+ ]
98
+
99
+ # Colors
100
+ palette_colorblind = sns.color_palette("colorblind")
101
+ color_climategan = palette_colorblind[0]
102
+ color_munit = palette_colorblind[1]
103
+ color_cyclegan = palette_colorblind[6]
104
+ color_instagan = palette_colorblind[8]
105
+ color_maskinstagan = palette_colorblind[2]
106
+ color_paintedground = palette_colorblind[3]
107
+
108
+ color_cat1 = palette_colorblind[0]
109
+ color_cat2 = palette_colorblind[1]
110
+ palette_lightest = [
111
+ sns.light_palette(color_cat1, n_colors=20)[3],
112
+ sns.light_palette(color_cat2, n_colors=20)[3],
113
+ ]
114
+ palette_light = [
115
+ sns.light_palette(color_cat1, n_colors=3)[1],
116
+ sns.light_palette(color_cat2, n_colors=3)[1],
117
+ ]
118
+ palette_medium = [color_cat1, color_cat2]
119
+ palette_dark = [
120
+ sns.dark_palette(color_cat1, n_colors=3)[1],
121
+ sns.dark_palette(color_cat2, n_colors=3)[1],
122
+ ]
123
+ palette_cat1 = [
124
+ palette_lightest[0],
125
+ palette_light[0],
126
+ palette_medium[0],
127
+ palette_dark[0],
128
+ ]
129
+ palette_cat2 = [
130
+ palette_lightest[1],
131
+ palette_light[1],
132
+ palette_medium[1],
133
+ palette_dark[1],
134
+ ]
135
+ color_cat1_light = palette_light[0]
136
+ color_cat2_light = palette_light[1]
137
+
138
+
139
def parsed_args():
    """
    Parse and return the command-line arguments for the ablation-comparison
    plotting script.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models",
        default="all",
        type=str,
        help="Models to display: all, pseudo, no_dada_masker, no_baseline",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # NOTE: the default must be an int literal. argparse only applies
        # `type` to *string* defaults, so `default=1e6` would silently yield
        # the float 1000000.0 and later break int-only uses such as
        # `range(args.n_bs)` or `np.zeros(args.n_bs)`.
        default=1_000_000,
        type=int,
        help="Number of bootrstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()
191
+
192
+
193
+ def plot_median_metrics(
194
+ df, do_stripplot=True, dpi=200, bs_seed=37, n_bs=1000, **snskwargs
195
+ ):
196
+ def plot_metric(
197
+ ax, df, metric, do_stripplot=True, dpi=200, bs_seed=37, marker="o", **snskwargs
198
+ ):
199
+
200
+ y_labels = [dict_models[f] for f in df.model_feats.unique()]
201
+
202
+ # Labels
203
+ y_labels_int = np.sort([el for el in y_labels if isinstance(el, int)]).tolist()
204
+ y_order_int = [
205
+ k for vs in y_labels_int for k, vu in dict_models.items() if vs == vu
206
+ ]
207
+ y_labels_int = [str(el) for el in y_labels_int]
208
+
209
+ y_labels_str = sorted([el for el in y_labels if not isinstance(el, int)])
210
+ y_order_str = [
211
+ k for vs in y_labels_str for k, vu in dict_models.items() if vs == vu
212
+ ]
213
+ y_labels = y_labels_int + y_labels_str
214
+ y_order = y_order_int + y_order_str
215
+
216
+ # Palette
217
+ palette = len(y_labels_int) * [color_climategan]
218
+ for y in y_labels_str:
219
+ if y == "G":
220
+ palette = palette + [color_paintedground]
221
+ if y == "I":
222
+ palette = palette + [color_maskinstagan]
223
+
224
+ # Error
225
+ sns.pointplot(
226
+ ax=ax,
227
+ data=df,
228
+ x=metric,
229
+ y="model_feats",
230
+ order=y_order,
231
+ markers=marker,
232
+ estimator=np.median,
233
+ ci=99,
234
+ seed=bs_seed,
235
+ n_boot=n_bs,
236
+ join=False,
237
+ scale=0.6,
238
+ errwidth=1.5,
239
+ capsize=0.1,
240
+ palette=palette,
241
+ )
242
+ xlim = ax.get_xlim()
243
+
244
+ if do_stripplot:
245
+ sns.stripplot(
246
+ ax=ax,
247
+ data=df,
248
+ x=metric,
249
+ y="model_feats",
250
+ size=1.5,
251
+ palette=palette,
252
+ alpha=0.2,
253
+ )
254
+ ax.set_xlim(xlim)
255
+
256
+ # Set X-label
257
+ ax.set_xlabel(dict_metrics["names"][metric], rotation=0, fontsize="medium")
258
+
259
+ # Set Y-label
260
+ ax.set_ylabel(None)
261
+
262
+ ax.set_yticklabels(y_labels, fontsize="medium")
263
+
264
+ # Change spines
265
+ sns.despine(ax=ax, left=True, bottom=True)
266
+
267
+ # Draw gray area on final model
268
+ xlim = ax.get_xlim()
269
+ ylim = ax.get_ylim()
270
+ trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
271
+ rect = mpatches.Rectangle(
272
+ xy=(0.0, 5.5),
273
+ width=1,
274
+ height=1,
275
+ transform=trans,
276
+ linewidth=0.0,
277
+ edgecolor="none",
278
+ facecolor="gray",
279
+ alpha=0.05,
280
+ )
281
+ ax.add_patch(rect)
282
+
283
+ # Set up plot
284
+ sns.set(style="whitegrid")
285
+ plt.rcParams.update({"font.family": "serif"})
286
+ plt.rcParams.update(
287
+ {
288
+ "font.serif": [
289
+ "Computer Modern Roman",
290
+ "Times New Roman",
291
+ "Utopia",
292
+ "New Century Schoolbook",
293
+ "Century Schoolbook L",
294
+ "ITC Bookman",
295
+ "Bookman",
296
+ "Times",
297
+ "Palatino",
298
+ "Charter",
299
+ "serif" "Bitstream Vera Serif",
300
+ "DejaVu Serif",
301
+ ]
302
+ }
303
+ )
304
+
305
+ fig_h = 0.4 * len(df.model_feats.unique())
306
+ fig, axes = plt.subplots(
307
+ nrows=1, ncols=3, sharey=True, dpi=dpi, figsize=(18, fig_h)
308
+ )
309
+
310
+ # Error
311
+ plot_metric(
312
+ axes[0],
313
+ df,
314
+ "error",
315
+ do_stripplot=do_stripplot,
316
+ dpi=dpi,
317
+ bs_seed=bs_seed,
318
+ marker=dict_markers["error"],
319
+ )
320
+ axes[0].set_ylabel("Models")
321
+
322
+ # F05
323
+ plot_metric(
324
+ axes[1],
325
+ df,
326
+ "f05",
327
+ do_stripplot=do_stripplot,
328
+ dpi=dpi,
329
+ bs_seed=bs_seed,
330
+ marker=dict_markers["f05"],
331
+ )
332
+
333
+ # Edge coherence
334
+ plot_metric(
335
+ axes[2],
336
+ df,
337
+ "edge_coherence",
338
+ do_stripplot=do_stripplot,
339
+ dpi=dpi,
340
+ bs_seed=bs_seed,
341
+ marker=dict_markers["edge_coherence"],
342
+ )
343
+ xticks = axes[2].get_xticks()
344
+ xticklabels = ["{:.3f}".format(x) for x in xticks]
345
+ axes[2].set(xticks=xticks, xticklabels=xticklabels)
346
+
347
+ plt.subplots_adjust(wspace=0.12)
348
+
349
+ return fig
350
+
351
+
352
if __name__ == "__main__":
    # -----------------------------
    # ----- Parse arguments -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f"  {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir.
    # NOTE(review): when --output_dir is not given this reads SLURM_TMPDIR and
    # will raise KeyError outside a SLURM job — presumably intentional for the
    # cluster this was run on.
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
        # Only an explicitly-given output dir is created; SLURM_TMPDIR is
        # assumed to exist.
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=False)

    # Store args alongside the figure for reproducibility
    output_yml = output_dir / "ablation_comparison_{}.yml".format(args.models)
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV of per-image metrics for all models
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Determine models: filter rows according to the --models keyword(s).
    # Several keywords can be combined in one string; filters are applied in
    # sequence.
    if "all" in args.models.lower():
        pass
    else:
        if "no_baseline" in args.models.lower():
            # Drop the painted-ground and instagan baselines
            df = df.loc[(df.ground == False) & (df.instagan == False)]
        if "pseudo" in args.models.lower():
            # Keep pseudo-label models plus the two baselines
            df = df.loc[
                (df.pseudo == True) | (df.ground == True) | (df.instagan == True)
            ]
        if "no_dada_mask" in args.models.lower():
            # Drop DADA-masker models, keeping the two baselines
            df = df.loc[
                (df.dada_masker == False) | (df.ground == True) | (df.instagan == True)
            ]

    fig = plot_median_metrics(df, do_stripplot=True, dpi=args.dpi, bs_seed=args.bs_seed)

    # Save figure
    output_fig = output_dir / "ablation_comparison_{}.png".format(args.models)
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
figures/bootstrap_ablation.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script evaluates the contribution of a technique from the ablation study for
3
+ improving the masker evaluation metrics. The differences in the metrics are computed
4
+ for all images of paired models, that is those which only differ in the inclusion or
5
+ not of the given technique. Then, statistical inference is performed through the
6
+ percentile bootstrap to obtain robust estimates of the differences in the metrics and
7
+ confidence intervals. The script plots the distribution of the bootrstraped estimates.
8
+ """
9
+ print("Imports...", end="")
10
+ from argparse import ArgumentParser
11
+ import yaml
12
+ import os
13
+ import numpy as np
14
+ import pandas as pd
15
+ import seaborn as sns
16
+ from scipy.stats import trim_mean
17
+ from tqdm import tqdm
18
+ from pathlib import Path
19
+ import matplotlib.pyplot as plt
20
+ import matplotlib.patches as mpatches
21
+
22
+
23
+ # -----------------------
24
+ # ----- Constants -----
25
+ # -----------------------
26
+
27
+ dict_metrics = {
28
+ "names": {
29
+ "tpr": "TPR, Recall, Sensitivity",
30
+ "tnr": "TNR, Specificity, Selectivity",
31
+ "fpr": "FPR",
32
+ "fpt": "False positives relative to image size",
33
+ "fnr": "FNR, Miss rate",
34
+ "fnt": "False negatives relative to image size",
35
+ "mpr": "May positive rate (MPR)",
36
+ "mnr": "May negative rate (MNR)",
37
+ "accuracy": "Accuracy (ignoring may)",
38
+ "error": "Error",
39
+ "f05": "F05 score",
40
+ "precision": "Precision",
41
+ "edge_coherence": "Edge coherence",
42
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
43
+ },
44
+ "key_metrics": ["f05", "error", "edge_coherence"],
45
+ }
46
+ dict_techniques = {
47
+ "depth": "depth",
48
+ "segmentation": "seg",
49
+ "seg": "seg",
50
+ "dada_s": "dada_seg",
51
+ "dada_seg": "dada_seg",
52
+ "dada_segmentation": "dada_seg",
53
+ "dada_m": "dada_masker",
54
+ "dada_masker": "dada_masker",
55
+ "spade": "spade",
56
+ "pseudo": "pseudo",
57
+ "pseudo-labels": "pseudo",
58
+ "pseudo_labels": "pseudo",
59
+ }
60
+
61
+ # Model features
62
+ model_feats = [
63
+ "masker",
64
+ "seg",
65
+ "depth",
66
+ "dada_seg",
67
+ "dada_masker",
68
+ "spade",
69
+ "pseudo",
70
+ "ground",
71
+ "instagan",
72
+ ]
73
+
74
+ # Colors
75
+ palette_colorblind = sns.color_palette("colorblind")
76
+ color_cat1 = palette_colorblind[0]
77
+ color_cat2 = palette_colorblind[1]
78
+ palette_lightest = [
79
+ sns.light_palette(color_cat1, n_colors=20)[3],
80
+ sns.light_palette(color_cat2, n_colors=20)[3],
81
+ ]
82
+ palette_light = [
83
+ sns.light_palette(color_cat1, n_colors=3)[1],
84
+ sns.light_palette(color_cat2, n_colors=3)[1],
85
+ ]
86
+ palette_medium = [color_cat1, color_cat2]
87
+ palette_dark = [
88
+ sns.dark_palette(color_cat1, n_colors=3)[1],
89
+ sns.dark_palette(color_cat2, n_colors=3)[1],
90
+ ]
91
+ palette_cat1 = [
92
+ palette_lightest[0],
93
+ palette_light[0],
94
+ palette_medium[0],
95
+ palette_dark[0],
96
+ ]
97
+ palette_cat2 = [
98
+ palette_lightest[1],
99
+ palette_light[1],
100
+ palette_medium[1],
101
+ palette_dark[1],
102
+ ]
103
+ color_cat1_light = palette_light[0]
104
+ color_cat2_light = palette_light[1]
105
+
106
+
107
def parsed_args():
    """
    Parse and return the command-line arguments for the single-technique
    bootstrap script.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--technique",
        default=None,
        type=str,
        help="Keyword specifying the technique. One of: pseudo, depth, segmentation, dada_seg, dada_masker, spade",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # NOTE: the default must be an int literal. argparse only applies
        # `type` to *string* defaults, so `default=1e6` would silently yield
        # the float 1000000.0, which crashes `np.zeros(args.n_bs)` and
        # `range(args.n_bs)` in the bootstrap loop below.
        default=1_000_000,
        type=int,
        help="Number of bootrstrap samples",
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()
159
+
160
+
161
def add_ci_mean(
    ax, sample_measure, bs_mean, bs_std, ci, color, alpha, fontsize, invert=False
):
    """
    Annotate a KDE plot of bootstrap estimates with the confidence interval,
    the bootstrap mean and SD, and the sample measure.

    Assumes `ax.lines[0]` is the KDE curve (it reads its x/y data to anchor
    the annotations). Mutates `ax` in place; returns None.

    Args:
        ax: matplotlib Axes holding the KDE curve as its first line
        sample_measure (float): point estimate computed on the original sample
        bs_mean (float): mean of the bootstrap distribution
        bs_std (float): SD of the bootstrap distribution (used as the SE)
        ci (array-like): (lower, upper) confidence-interval bounds
        color: fill/line color for the CI region
        alpha (float): opacity of the CI fill
        fontsize: font size for all annotations
        invert (bool): flip the horizontal alignment of the L/U labels
    """

    # Read the KDE curve to anchor every annotation on it
    dist = ax.lines[0]
    dist_y = dist.get_ydata()
    dist_x = dist.get_xdata()
    linewidth = dist.get_linewidth()

    # Nearest curve samples to the CI bounds; fill the area between them
    x_idx_low = np.argmin(np.abs(dist_x - ci[0]))
    x_idx_high = np.argmin(np.abs(dist_x - ci[1]))
    x_ci = dist_x[x_idx_low:x_idx_high]
    y_ci = dist_y[x_idx_low:x_idx_high]

    ax.fill_between(x_ci, 0, y_ci, facecolor=color, alpha=alpha)

    # Add vertical lines at the CI bounds, up to the curve height
    ax.vlines(
        x=ci[0],
        ymin=0.0,
        ymax=y_ci[0],
        color=color,
        linewidth=linewidth,
        label="ci_low",
    )
    ax.vlines(
        x=ci[1],
        ymin=0.0,
        ymax=y_ci[-1],
        color=color,
        linewidth=linewidth,
        label="ci_high",
    )

    # Text boxes with the numeric CI bounds (L = lower, U = upper)
    bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)

    # `invert` swaps which side of each bound the label sits on
    if invert:
        ha_l = "right"
        ha_u = "left"
    else:
        ha_l = "left"
        ha_u = "right"
    ax.text(
        ci[0],
        0.0,
        s="L = {:.4f}".format(ci[0]),
        ha=ha_l,
        va="bottom",
        fontsize=fontsize,
        bbox=bbox_props,
    )
    ax.text(
        ci[1],
        0.0,
        s="U = {:.4f}".format(ci[1]),
        ha=ha_u,
        va="bottom",
        fontsize=fontsize,
        bbox=bbox_props,
    )

    # Add vertical line at the bootstrap mean
    x_idx_mean = np.argmin(np.abs(dist_x - bs_mean))
    ax.vlines(
        x=bs_mean, ymin=0.0, ymax=dist_y[x_idx_mean], color="k", linewidth=linewidth
    )

    # Add annotation of bootstrap mean (boxed, centered on the line)
    bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)

    ax.text(
        bs_mean,
        0.6 * dist_y[x_idx_mean],
        s="Bootstrap mean = {:.4f}".format(bs_mean),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=bbox_props,
    )

    # Dotted vertical line at the original-sample measure, for comparison
    x_idx_smeas = np.argmin(np.abs(dist_x - sample_measure))
    ax.vlines(
        x=sample_measure,
        ymin=0.0,
        ymax=dist_y[x_idx_smeas],
        color="k",
        linewidth=linewidth,
        linestyles="dotted",
    )

    # Add SD annotation in a double-arrow box (bootstrap SD = standard error)
    bbox_props = dict(boxstyle="darrow, pad=0.4", fc="w", ec="k", lw=2)

    ax.text(
        bs_mean,
        0.4 * dist_y[x_idx_mean],
        s="SD = {:.4f} = SE".format(bs_std),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=bbox_props,
    )
266
+
267
+
268
def add_null_pval(ax, null, color, alpha, fontsize):
    """
    Shade the tail of the bootstrap KDE beyond the null-hypothesis value and
    draw a labeled vertical line at the null.

    Assumes `ax.lines[0]` is the KDE curve. Mutates `ax` in place; returns
    None.

    Args:
        ax: matplotlib Axes holding the KDE curve as its first line
        null (float): null-hypothesis value (e.g. 0.0 for "no difference")
        color: fill color for the shaded tail
        alpha (float): opacity of the tail fill
        fontsize: font size for the annotation
    """

    # Read the KDE curve to locate the null value on it
    dist = ax.lines[0]
    dist_y = dist.get_ydata()
    dist_x = dist.get_xdata()
    linewidth = dist.get_linewidth()

    # Shade the smaller tail: right of the null if the null falls in the
    # right half of the curve's support, otherwise left of it
    x_idx_null = np.argmin(np.abs(dist_x - null))
    if x_idx_null >= (len(dist_x) / 2.0):
        x_pval = dist_x[x_idx_null:]
        y_pval = dist_y[x_idx_null:]
    else:
        x_pval = dist_x[:x_idx_null]
        y_pval = dist_y[:x_idx_null]

    ax.fill_between(x_pval, 0, y_pval, facecolor=color, alpha=alpha)

    # Full-height dotted vertical line at the null value
    dist = ax.lines[0]
    linewidth = dist.get_linewidth()
    y_max = ax.get_ylim()[1]
    ax.vlines(
        x=null,
        ymin=0.0,
        ymax=y_max,
        color="k",
        linewidth=linewidth,
        linestyles="dotted",
    )

    # Boxed label stating the null-hypothesis value
    bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)

    ax.text(
        null,
        0.75 * y_max,
        s="Null hypothesis = {:.1f}".format(null),
        ha="center",
        va="center",
        fontsize=fontsize,
        bbox=bbox_props,
    )
311
+
312
+
313
def plot_bootstrap_distr(
    sample_measure, bs_samples, alpha, color_ci, color_pval=None, null=None
):
    """
    Plot the distribution (KDE) of bootstrap estimates, annotated with the
    percentile confidence interval, the bootstrap mean/SD, the sample
    measure, and — if a null hypothesis is provided — the two-sided p-value.

    Args:
        sample_measure (float): point estimate on the original sample
        bs_samples (array-like): bootstrap estimates
        alpha (float): confidence level (e.g. 0.99)
        color_ci: fill color for the confidence interval
        color_pval: fill color for the p-value tail; p-value annotations are
            only drawn when both `color_pval` and `null` are given
        null (float | None): null-hypothesis value

    Returns:
        tuple: (fig, bs_mean, bs_std, ci, pval); `pval` is None when no
        null hypothesis / p-value color was provided.

    Note:
        NOTE(review): the figure DPI is taken from the module-global `args`
        (set in `__main__`), not from a parameter — this function presumably
        is only meant to be called from this script.
    """

    # Compute results from bootstrap: percentile CI, mean and SD
    q_low = (1.0 - alpha) / 2.0
    q_high = 1.0 - q_low
    ci = np.quantile(bs_samples, [q_low, q_high])
    bs_mean = np.mean(bs_samples)
    bs_std = np.std(bs_samples)

    # Two-sided bootstrap p-value: twice the smaller tail probability.
    # FIX: `pval` is initialized so the final `return` does not raise
    # NameError when no null hypothesis is given.
    pval = None
    pval_flag = null is not None and color_pval is not None
    if pval_flag:
        pval = 2 * min(np.mean(bs_samples > null), np.mean(bs_samples < null))

    # Set up plot
    sns.set(style="whitegrid")
    fontsize = 24
    font = {"family": "DejaVu Sans", "weight": "normal", "size": fontsize}
    plt.rc("font", **font)
    alpha_plot = 0.5

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=(30, 12), dpi=args.dpi)

    # Plot distribution of bootstrap means
    sns.kdeplot(bs_samples, color="b", linewidth=5, gridsize=1000, ax=ax)

    y_lim = ax.get_ylim()

    # Change spines
    sns.despine(left=True, bottom=True)

    # Annotations: CI region, bootstrap mean/SD, sample measure
    add_ci_mean(
        ax,
        sample_measure,
        bs_mean,
        bs_std,
        ci,
        color=color_ci,
        alpha=alpha_plot,
        fontsize=fontsize,
    )

    if pval_flag:
        add_null_pval(ax, null, color=color_pval, alpha=alpha_plot, fontsize=fontsize)

    # Legend: CI patch, plus a p-value patch when applicable
    ci_patch = mpatches.Patch(
        facecolor=color_ci,
        edgecolor=None,
        alpha=alpha_plot,
        label="{:d} % confidence interval".format(int(100 * alpha)),
    )

    if pval_flag:
        # Label format depends on the magnitude of the (half) p-value
        if pval == 0.0:
            pval_patch = mpatches.Patch(
                facecolor=color_pval,
                edgecolor=None,
                alpha=alpha_plot,
                label="P value / 2 = {:.1f}".format(pval / 2.0),
            )
        elif np.around(pval / 2.0, decimals=4) > 0.0000:
            pval_patch = mpatches.Patch(
                facecolor=color_pval,
                edgecolor=None,
                alpha=alpha_plot,
                label="P value / 2 = {:.4f}".format(pval / 2.0),
            )
        else:
            pval_patch = mpatches.Patch(
                facecolor=color_pval,
                edgecolor=None,
                alpha=alpha_plot,
                label="P value / 2 < $10^{}$".format(np.ceil(np.log10(pval / 2.0))),
            )

        leg = ax.legend(
            handles=[ci_patch, pval_patch],
            ncol=1,
            loc="upper right",
            frameon=True,
            framealpha=1.0,
            title="",
            fontsize=fontsize,
            columnspacing=1.0,
            labelspacing=0.2,
            markerfirst=True,
        )
    else:
        leg = ax.legend(
            handles=[ci_patch],
            ncol=1,
            loc="upper right",
            frameon=True,
            framealpha=1.0,
            title="",
            fontsize=fontsize,
            columnspacing=1.0,
            labelspacing=0.2,
            markerfirst=True,
        )

    plt.setp(leg.get_title(), fontsize=fontsize, horizontalalignment="left")

    # Set X-label
    ax.set_xlabel("Bootstrap estimates", rotation=0, fontsize=fontsize, labelpad=10.0)

    # Set Y-label
    ax.set_ylabel("Density", rotation=90, fontsize=fontsize, labelpad=10.0)

    # Ticks
    plt.setp(ax.get_xticklabels(), fontsize=0.8 * fontsize, verticalalignment="top")
    plt.setp(ax.get_yticklabels(), fontsize=0.8 * fontsize)

    # Restore the y-limits recorded right after the KDE was drawn
    ax.set_ylim(y_lim)

    return fig, bs_mean, bs_std, ci, pval
435
+
436
+
437
if __name__ == "__main__":
    # -----------------------------
    # ----- Parse arguments -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f"  {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir.
    # NOTE(review): with no --output_dir this reads SLURM_TMPDIR and raises
    # KeyError outside a SLURM job — presumably intentional for the cluster
    # this was run on. Only an explicitly-given dir is created.
    if args.output_dir is None:
        output_dir = Path(os.environ["SLURM_TMPDIR"])
    else:
        output_dir = Path(args.output_dir)
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=False)

    # Store args alongside the results for reproducibility
    output_yml = output_dir / "{}_bootstrap.yml".format(args.technique)
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Determine technique: map the user keyword to the CSV column name
    if args.technique.lower() not in dict_techniques:
        raise ValueError("{} is not a valid technique".format(args.technique))
    else:
        technique = dict_techniques[args.technique.lower()]

    # Read CSV of per-image metrics for all models
    df = pd.read_csv(args.input_csv, index_col="model_img_idx")

    # Find relevant model pairs: (mi, mj) where mi uses the technique, mj
    # does not, and the two models agree on every other feature flag.
    model_pairs = []
    for mi in df.loc[df[technique]].model_feats.unique():
        for mj in df.model_feats.unique():
            if mj == mi:
                continue

            # mj must NOT use the technique
            if df.loc[df.model_feats == mj, technique].unique()[0]:
                continue

            # All other feature flags must match between mi and mj
            is_pair = True
            for f in model_feats:
                if f == technique:
                    continue
                elif (
                    df.loc[df.model_feats == mj, f].unique()[0]
                    != df.loc[df.model_feats == mi, f].unique()[0]
                ):
                    is_pair = False
                    break
                else:
                    pass
            # Keep at most one partner per mi
            if is_pair:
                model_pairs.append((mi, mj))
                break

    print("\nModel pairs identified:\n")
    for pair in model_pairs:
        print("{} & {}".format(pair[0], pair[1]))

    # NOTE(review): a "base" column is initialized but the loop writes to
    # "depth_base" — one of the two names looks like a leftover; neither
    # column is read again in this script. Confirm intent before relying on
    # either column downstream.
    df["base"] = ["N/A"] * len(df)
    for spp in model_pairs:
        df.loc[df.model_feats.isin(spp), "depth_base"] = spp[1]

    # Build bootstrap data: per-image metric differences (with - without),
    # pooled over all identified pairs, for each key metric
    data = {m: [] for m in dict_metrics["key_metrics"]}
    for m_with, m_without in model_pairs:
        df_with = df.loc[df.model_feats == m_with]
        df_without = df.loc[df.model_feats == m_without]
        for metric in data.keys():
            # Sort both sides by img_idx so differences are taken on the
            # same image
            diff = (
                df_with.sort_values(by="img_idx")[metric].values
                - df_without.sort_values(by="img_idx")[metric].values
            )
            data[metric].extend(diff.tolist())

    # Run bootstrap: n_bs resamples, storing mean, median and 20 % trimmed
    # mean of each resample
    measures = ["mean", "median", "20_trimmed_mean"]
    bs_data = {meas: {m: np.zeros(args.n_bs) for m in data.keys()} for meas in measures}

    np.random.seed(args.bs_seed)
    for m, data_m in data.items():
        for idx, s in enumerate(tqdm(range(args.n_bs))):
            # Sample with replacement
            bs_sample = np.random.choice(data_m, size=len(data_m), replace=True)

            # Store mean
            bs_data["mean"][m][idx] = np.mean(bs_sample)

            # Store median
            bs_data["median"][m][idx] = np.median(bs_sample)

            # Store 20 % trimmed mean
            bs_data["20_trimmed_mean"][m][idx] = trim_mean(bs_sample, 0.2)

    # Plot and store results for the 20 % trimmed mean only
    for metric in dict_metrics["key_metrics"]:
        sample_measure = trim_mean(data[metric], 0.2)
        fig, bs_mean, bs_std, ci, pval = plot_bootstrap_distr(
            sample_measure,
            bs_data["20_trimmed_mean"][metric],
            alpha=args.alpha,
            color_ci=color_cat1_light,
            color_pval=color_cat2_light,
            null=0.0,
        )

        # Save figure
        output_fig = output_dir / "{}_bootstrap_{}_{}.png".format(
            args.technique, metric, "20_trimmed_mean"
        )
        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")

        # Store results as YAML next to the figure
        output_results = output_dir / "{}_bootstrap_{}_{}.yml".format(
            args.technique, metric, "20_trimmed_mean"
        )
        results_dict = {
            "measure": "20_trimmed_mean",
            "sample_measure": float(sample_measure),
            "bs_mean": float(bs_mean),
            "bs_std": float(bs_std),
            "ci_left": float(ci[0]),
            "ci_right": float(ci[1]),
            "pval": float(pval),
        }
        with open(output_results, "w") as f:
            yaml.dump(results_dict, f)
figures/bootstrap_ablation_summary.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script computes the 20 % trimmed mean difference and confidence intervals of all techniques from the ablation study for
3
+ improving the masker evaluation metrics. The differences in the metrics are computed
4
+ for all images of paired models, that is those which only differ in the inclusion or
5
+ not of the given technique. Then, statistical inference is performed through the
6
+ percentile bootstrap to obtain robust estimates of the differences in the metrics and
7
+ confidence intervals. The script plots the summary for all techniques.
8
+ """
9
+ print("Imports...", end="")
10
from argparse import ArgumentParser
import os
from collections import OrderedDict
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from scipy.special import comb
from scipy.stats import trim_mean
from tqdm import tqdm
23
+
24
+
25
+ # -----------------------
26
+ # ----- Constants -----
27
+ # -----------------------
28
+
29
+ dict_metrics = {
30
+ "names": {
31
+ "tpr": "TPR, Recall, Sensitivity",
32
+ "tnr": "TNR, Specificity, Selectivity",
33
+ "fpr": "FPR",
34
+ "fpt": "False positives relative to image size",
35
+ "fnr": "FNR, Miss rate",
36
+ "fnt": "False negatives relative to image size",
37
+ "mpr": "May positive rate (MPR)",
38
+ "mnr": "May negative rate (MNR)",
39
+ "accuracy": "Accuracy (ignoring may)",
40
+ "error": "Error",
41
+ "f05": "F05 score",
42
+ "precision": "Precision",
43
+ "edge_coherence": "Edge coherence",
44
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
45
+ },
46
+ "key_metrics": ["error", "f05", "edge_coherence"],
47
+ }
48
+
49
+ dict_techniques = OrderedDict(
50
+ [
51
+ ("pseudo", "Pseudo labels"),
52
+ ("depth", "Depth (D)"),
53
+ ("seg", "Seg. (S)"),
54
+ ("spade", "SPADE"),
55
+ ("dada_seg", "DADA (S)"),
56
+ ("dada_masker", "DADA (M)"),
57
+ ]
58
+ )
59
+
60
+ # Model features
61
+ model_feats = [
62
+ "masker",
63
+ "seg",
64
+ "depth",
65
+ "dada_seg",
66
+ "dada_masker",
67
+ "spade",
68
+ "pseudo",
69
+ "ground",
70
+ "instagan",
71
+ ]
72
+
73
+ # Colors
74
+ crest = sns.color_palette("crest", as_cmap=False, n_colors=7)
75
+ palette_metrics = [crest[0], crest[3], crest[6]]
76
+ sns.palplot(palette_metrics)
77
+
78
+ # Markers
79
+ dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
80
+
81
+
82
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # Was `1e6` (a float): argparse applies `type` only to command-line
        # strings, never to `default`, so the default must already be an int.
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",  # fixed typo "bootrstrap"
    )
    parser.add_argument(
        "--alpha",
        default=0.99,
        type=float,
        help="Confidence level",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()
128
+
129
+
130
def trim_mean_wrapper(a):
    """20 % trimmed mean of `a`, usable as a single-argument estimator."""
    trimmed = trim_mean(a, proportiontocut=0.2)
    return trimmed
132
+
133
+
134
def find_model_pairs(technique, model_feats):
    """
    Return (with, without) pairs of model identifiers that differ only in
    `technique`.

    Relies on the module-level DataFrame `df`. For every model that includes
    `technique`, the first model found that lacks `technique` but matches it
    on every other feature flag is paired with it.
    """
    pairs = []
    other_feats = [f for f in model_feats if f != technique]
    candidates = df.model_feats.unique()
    for m_with in df.loc[df[technique]].model_feats.unique():
        for m_without in candidates:
            if m_without == m_with:
                continue
            # Skip candidates that also include the technique
            if df.loc[df.model_feats == m_without, technique].unique()[0]:
                continue
            # A pair must agree on every feature other than `technique`
            same_otherwise = all(
                df.loc[df.model_feats == m_without, f].unique()[0]
                == df.loc[df.model_feats == m_with, f].unique()[0]
                for f in other_feats
            )
            if same_otherwise:
                pairs.append((m_with, m_without))
                break
    return pairs
160
+
161
+
162
+ if __name__ == "__main__":
163
+ # -----------------------------
164
+ # ----- Parse arguments -----
165
+ # -----------------------------
166
+ args = parsed_args()
167
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
168
+
169
+ # Determine output dir
170
+ if args.output_dir is None:
171
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
172
+ else:
173
+ output_dir = Path(args.output_dir)
174
+ if not output_dir.exists():
175
+ output_dir.mkdir(parents=True, exist_ok=False)
176
+
177
+ # Store args
178
+ output_yml = output_dir / "bootstrap_summary.yml"
179
+ with open(output_yml, "w") as f:
180
+ yaml.dump(vars(args), f)
181
+
182
+ # Read CSV
183
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
184
+
185
+ # Build data set
186
+ dfbs = pd.DataFrame(columns=["diff", "technique", "metric"])
187
+ for technique in model_feats:
188
+
189
+ # Get pairs
190
+ model_pairs = find_model_pairs(technique, model_feats)
191
+
192
+ # Compute differences
193
+ for m_with, m_without in model_pairs:
194
+ df_with = df.loc[df.model_feats == m_with]
195
+ df_without = df.loc[df.model_feats == m_without]
196
+ for metric in dict_metrics["key_metrics"]:
197
+ diff = (
198
+ df_with.sort_values(by="img_idx")[metric].values
199
+ - df_without.sort_values(by="img_idx")[metric].values
200
+ )
201
+ dfm = pd.DataFrame.from_dict(
202
+ {"metric": metric, "technique": technique, "diff": diff}
203
+ )
204
+ dfbs = dfbs.append(dfm, ignore_index=True)
205
+
206
+ ### Plot
207
+
208
+ # Set up plot
209
+ sns.reset_orig()
210
+ sns.set(style="whitegrid")
211
+ plt.rcParams.update({"font.family": "serif"})
212
+ plt.rcParams.update(
213
+ {
214
+ "font.serif": [
215
+ "Computer Modern Roman",
216
+ "Times New Roman",
217
+ "Utopia",
218
+ "New Century Schoolbook",
219
+ "Century Schoolbook L",
220
+ "ITC Bookman",
221
+ "Bookman",
222
+ "Times",
223
+ "Palatino",
224
+ "Charter",
225
+ "serif" "Bitstream Vera Serif",
226
+ "DejaVu Serif",
227
+ ]
228
+ }
229
+ )
230
+
231
+ fig, axes = plt.subplots(
232
+ nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3)
233
+ )
234
+
235
+ metrics = ["error", "f05", "edge_coherence"]
236
+ dict_ci = {m: {} for m in metrics}
237
+
238
+ for idx, metric in enumerate(dict_metrics["key_metrics"]):
239
+
240
+ ax = sns.pointplot(
241
+ ax=axes[idx],
242
+ data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])],
243
+ order=dict_techniques.keys(),
244
+ x="diff",
245
+ y="technique",
246
+ hue="metric",
247
+ hue_order=[metric],
248
+ markers=dict_markers[metric],
249
+ palette=[palette_metrics[idx]],
250
+ errwidth=1.5,
251
+ scale=0.6,
252
+ join=False,
253
+ estimator=trim_mean_wrapper,
254
+ ci=int(args.alpha * 100),
255
+ n_boot=args.n_bs,
256
+ seed=args.bs_seed,
257
+ )
258
+
259
+ # Retrieve confidence intervals and update results dictionary
260
+ for line, technique in zip(ax.lines, dict_techniques.keys()):
261
+ dict_ci[metric].update(
262
+ {
263
+ technique: {
264
+ "20_trimmed_mean": float(
265
+ trim_mean_wrapper(
266
+ dfbs.loc[
267
+ (dfbs.technique == technique)
268
+ & (dfbs.metric == metrics[idx]),
269
+ "diff",
270
+ ].values
271
+ )
272
+ ),
273
+ "ci_left": float(line.get_xdata()[0]),
274
+ "ci_right": float(line.get_xdata()[1]),
275
+ }
276
+ }
277
+ )
278
+
279
+ leg_handles, leg_labels = ax.get_legend_handles_labels()
280
+
281
+ # Change spines
282
+ sns.despine(left=True, bottom=True)
283
+
284
+ # Set Y-label
285
+ ax.set_ylabel(None)
286
+
287
+ # Y-tick labels
288
+ ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium")
289
+
290
+ # Set X-label
291
+ ax.set_xlabel(None)
292
+
293
+ # X-ticks
294
+ xticks = ax.get_xticks()
295
+ xticklabels = xticks
296
+ ax.set_xticks(xticks)
297
+ ax.set_xticklabels(xticklabels, fontsize="small")
298
+
299
+ # Y-lim
300
+ display2data = ax.transData.inverted()
301
+ ax2display = ax.transAxes
302
+ _, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02)))
303
+ _, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98)))
304
+ ax.set_ylim(bottom=y_bottom, top=y_top)
305
+
306
+ # Draw line at H0
307
+ y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1)
308
+ x = 0.0 * np.ones(y.shape[0])
309
+ ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
310
+
311
+ # Draw gray area
312
+ xlim = ax.get_xlim()
313
+ ylim = ax.get_ylim()
314
+ if metric == "error":
315
+ x0 = xlim[0]
316
+ width = np.abs(x0)
317
+ else:
318
+ x0 = 0.0
319
+ width = np.abs(xlim[1])
320
+ trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
321
+ rect = mpatches.Rectangle(
322
+ xy=(x0, 0.0),
323
+ width=width,
324
+ height=1,
325
+ transform=trans,
326
+ linewidth=0.0,
327
+ edgecolor="none",
328
+ facecolor="gray",
329
+ alpha=0.05,
330
+ )
331
+ ax.add_patch(rect)
332
+
333
+ # Legend
334
+ leg_handles, leg_labels = ax.get_legend_handles_labels()
335
+ leg_labels = [dict_metrics["names"][metric] for metric in leg_labels]
336
+ leg = ax.legend(
337
+ handles=leg_handles,
338
+ labels=leg_labels,
339
+ loc="center",
340
+ title="",
341
+ bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0),
342
+ framealpha=1.0,
343
+ frameon=False,
344
+ handletextpad=-0.2,
345
+ )
346
+
347
+ # Set X-label (title) │
348
+ fig.suptitle(
349
+ "20 % trimmed mean difference and bootstrapped confidence intervals",
350
+ y=0.0,
351
+ fontsize="medium",
352
+ )
353
+
354
+ # Save figure
355
+ output_fig = output_dir / "bootstrap_summary.png"
356
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
357
+
358
+ # Store results
359
+ output_results = output_dir / "bootstrap_summary_results.yml"
360
+ with open(output_results, "w") as f:
361
+ yaml.dump(dict_ci, f)
figures/human_evaluation.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script plots the result of the human evaluation on Amazon Mechanical Turk, where
3
+ human participants chose between an image from ClimateGAN or from a different method.
4
+ """
5
+ print("Imports...", end="")
6
+ from argparse import ArgumentParser
7
+ import os
8
+ import yaml
9
+ import numpy as np
10
+ import pandas as pd
11
+ import seaborn as sns
12
+ from pathlib import Path
13
+ import matplotlib.pyplot as plt
14
+
15
+
16
+ # -----------------------
17
+ # ----- Constants -----
18
+ # -----------------------
19
+
20
+ comparables_dict = {
21
+ "munit_flooded": "MUNIT",
22
+ "cyclegan": "CycleGAN",
23
+ "instagan": "InstaGAN",
24
+ "instagan_copypaste": "Mask-InstaGAN",
25
+ "painted_ground": "Painted ground",
26
+ }
27
+
28
+
29
+ # Colors
30
+ palette_colorblind = sns.color_palette("colorblind")
31
+ color_climategan = palette_colorblind[9]
32
+
33
+ palette_colorblind = sns.color_palette("colorblind")
34
+ color_munit = palette_colorblind[1]
35
+ color_cyclegan = palette_colorblind[2]
36
+ color_instagan = palette_colorblind[3]
37
+ color_maskinstagan = palette_colorblind[6]
38
+ color_paintedground = palette_colorblind[8]
39
+ palette_comparables = [
40
+ color_munit,
41
+ color_cyclegan,
42
+ color_instagan,
43
+ color_maskinstagan,
44
+ color_paintedground,
45
+ ]
46
+ palette_comparables_light = [
47
+ sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
48
+ ]
49
+
50
+
51
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="amt_omni-vs-other.csv",
        type=str,
        help="CSV containing the results of the human evaluation, pre-processed",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # Was `1e6` (a float): argparse applies `type` only to command-line
        # strings, never to `default`, so the default must already be an int.
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",  # fixed typo "bootrstrap"
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()
92
+
93
+ if __name__ == "__main__":
94
+ # -----------------------------
95
+ # ----- Parse arguments -----
96
+ # -----------------------------
97
+ args = parsed_args()
98
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
99
+
100
+ # Determine output dir
101
+ if args.output_dir is None:
102
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
103
+ else:
104
+ output_dir = Path(args.output_dir)
105
+ if not output_dir.exists():
106
+ output_dir.mkdir(parents=True, exist_ok=False)
107
+
108
+ # Store args
109
+ output_yml = output_dir / "args_human_evaluation.yml"
110
+ with open(output_yml, "w") as f:
111
+ yaml.dump(vars(args), f)
112
+
113
+ # Read CSV
114
+ df = pd.read_csv(args.input_csv)
115
+
116
+ # Sort Y labels
117
+ comparables = df.comparable.unique()
118
+ is_climategan_sum = [
119
+ df.loc[df.comparable == c, "climategan"].sum() for c in comparables
120
+ ]
121
+ comparables = comparables[np.argsort(is_climategan_sum)[::-1]]
122
+
123
+ # Plot setup
124
+ sns.set(style="whitegrid")
125
+ plt.rcParams.update({"font.family": "serif"})
126
+ plt.rcParams.update(
127
+ {
128
+ "font.serif": [
129
+ "Computer Modern Roman",
130
+ "Times New Roman",
131
+ "Utopia",
132
+ "New Century Schoolbook",
133
+ "Century Schoolbook L",
134
+ "ITC Bookman",
135
+ "Bookman",
136
+ "Times",
137
+ "Palatino",
138
+ "Charter",
139
+ "serif" "Bitstream Vera Serif",
140
+ "DejaVu Serif",
141
+ ]
142
+ }
143
+ )
144
+ fontsize = "medium"
145
+
146
+ # Initialize the matplotlib figure
147
+ fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)
148
+
149
+ # Plot the total (right)
150
+ sns.barplot(
151
+ data=df.loc[df.is_valid],
152
+ x="is_valid",
153
+ y="comparable",
154
+ order=comparables,
155
+ orient="h",
156
+ label="comparable",
157
+ palette=palette_comparables_light,
158
+ ci=None,
159
+ )
160
+
161
+ # Plot the left
162
+ sns.barplot(
163
+ data=df.loc[df.is_valid],
164
+ x="climategan",
165
+ y="comparable",
166
+ order=comparables,
167
+ orient="h",
168
+ label="climategan",
169
+ color=color_climategan,
170
+ ci=99,
171
+ n_boot=args.n_bs,
172
+ seed=args.bs_seed,
173
+ errcolor="black",
174
+ errwidth=1.5,
175
+ capsize=0.1,
176
+ )
177
+
178
+ # Draw line at 0.5
179
+ y = np.arange(ax.get_ylim()[1] + 0.1, ax.get_ylim()[0], 0.1)
180
+ x = 0.5 * np.ones(y.shape[0])
181
+ ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
182
+
183
+ # Change Y-Tick labels
184
+ yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
185
+ yticklabels_text = ax.set_yticklabels(
186
+ yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
187
+ )
188
+ for ytl in yticklabels_text:
189
+ ax.add_artist(ytl)
190
+
191
+ # Remove Y-label
192
+ ax.set_ylabel(ylabel="")
193
+
194
+ # Change X-Tick labels
195
+ xlim = [0.0, 1.1]
196
+ xticks = np.arange(xlim[0], xlim[1], 0.1)
197
+ ax.set(xticks=xticks)
198
+ plt.setp(ax.get_xticklabels(), fontsize=fontsize)
199
+
200
+ # Set X-label
201
+ ax.set_xlabel(None)
202
+
203
+ # Change spines
204
+ sns.despine(left=True, bottom=True)
205
+
206
+ # Save figure
207
+ output_fig = output_dir / "human_evaluation_rate_climategan.png"
208
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
figures/labels.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script plots images from the Masker test set overlaid with their labels.
3
+ """
4
+ print("Imports...", end="")
5
+ from argparse import ArgumentParser
6
+ import os
7
+ import yaml
8
+ import numpy as np
9
+ import pandas as pd
10
+ import seaborn as sns
11
+ from pathlib import Path
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+
15
+ import sys
16
+
17
+ sys.path.append("../")
18
+
19
+ from eval_masker import crop_and_resize
20
+
21
+
22
+ # -----------------------
23
+ # ----- Constants -----
24
+ # -----------------------
25
+
26
+ # Colors
27
+ colorblind_palette = sns.color_palette("colorblind")
28
+ color_cannot = colorblind_palette[1]
29
+ color_must = colorblind_palette[2]
30
+ color_may = colorblind_palette[7]
31
+ color_pred = colorblind_palette[4]
32
+
33
+ icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
34
+ color_tp = icefire[0]
35
+ color_tn = icefire[1]
36
+ color_fp = icefire[4]
37
+ color_fn = icefire[3]
38
+
39
+
40
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    # (flag, options) pairs, declared once and registered in a loop
    arg_specs = (
        (
            "--input_csv",
            dict(
                default="ablations_metrics_20210311.csv",
                type=str,
                help="CSV containing the results of the ablation study",
            ),
        ),
        ("--output_dir", dict(default=None, type=str, help="Output directory")),
        (
            "--masker_test_set_dir",
            dict(default=None, type=str, help="Directory containing the test images"),
        ),
        (
            "--images",
            dict(
                nargs="+",
                help="List of image file names to plot",
                default=[],
                type=str,
            ),
        ),
        ("--dpi", dict(default=200, type=int, help="DPI for the output images")),
        (
            "--alpha",
            dict(default=0.5, type=float, help="Transparency of labels shade"),
        ),
    )

    parser = ArgumentParser()
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
87
+
88
+
89
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Replace every pixel of `arr` whose channels are numerically close to
    `input_color` with `output_color`; other pixels are left untouched.
    Returns a new array (`arr` is not modified).
    """
    reference = np.tile(input_color, arr.shape[:2] + (1,))
    result = arr.copy()
    is_match = np.all(np.isclose(arr, reference, rtol=rtol), axis=2)
    result[is_match] = output_color
    return result
97
+
98
+
99
+ if __name__ == "__main__":
100
+ # -----------------------------
101
+ # ----- Parse arguments -----
102
+ # -----------------------------
103
+ args = parsed_args()
104
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
105
+
106
+ # Determine output dir
107
+ if args.output_dir is None:
108
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
109
+ else:
110
+ output_dir = Path(args.output_dir)
111
+ if not output_dir.exists():
112
+ output_dir.mkdir(parents=True, exist_ok=False)
113
+
114
+ # Store args
115
+ output_yml = output_dir / "labels.yml"
116
+ with open(output_yml, "w") as f:
117
+ yaml.dump(vars(args), f)
118
+
119
+ # Data dirs
120
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
121
+ labels_path = Path(args.masker_test_set_dir) / "labels"
122
+
123
+ # Read CSV
124
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
125
+
126
+ # Set up plot
127
+ sns.reset_orig()
128
+ sns.set(style="whitegrid")
129
+ plt.rcParams.update({"font.family": "serif"})
130
+ plt.rcParams.update(
131
+ {
132
+ "font.serif": [
133
+ "Computer Modern Roman",
134
+ "Times New Roman",
135
+ "Utopia",
136
+ "New Century Schoolbook",
137
+ "Century Schoolbook L",
138
+ "ITC Bookman",
139
+ "Bookman",
140
+ "Times",
141
+ "Palatino",
142
+ "Charter",
143
+ "serif" "Bitstream Vera Serif",
144
+ "DejaVu Serif",
145
+ ]
146
+ }
147
+ )
148
+
149
+ fig, axes = plt.subplots(
150
+ nrows=1, ncols=len(args.images), dpi=args.dpi, figsize=(len(args.images) * 5, 5)
151
+ )
152
+
153
+ for idx, img_filename in enumerate(args.images):
154
+
155
+ # Read images
156
+ img_path = imgs_orig_path / img_filename
157
+ label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
158
+ img, label = crop_and_resize(img_path, label_path)
159
+
160
+ # Map label colors
161
+ label_colmap = label.astype(float)
162
+ label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
163
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
164
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
165
+
166
+ ax = axes[idx]
167
+ ax.imshow(img)
168
+ ax.imshow(label_colmap, alpha=args.alpha)
169
+ ax.axis("off")
170
+
171
+ # Legend
172
+ handles = []
173
+ lw = 1.0
174
+ handles.append(
175
+ mpatches.Patch(
176
+ facecolor=color_must, label="must", linewidth=lw, alpha=args.alpha
177
+ )
178
+ )
179
+ handles.append(
180
+ mpatches.Patch(facecolor=color_may, label="may", linewidth=lw, alpha=args.alpha)
181
+ )
182
+ handles.append(
183
+ mpatches.Patch(
184
+ facecolor=color_cannot, label="cannot", linewidth=lw, alpha=args.alpha
185
+ )
186
+ )
187
+ labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
188
+ fig.legend(
189
+ handles=handles,
190
+ labels=labels,
191
+ loc="upper center",
192
+ bbox_to_anchor=(0.0, 0.85, 1.0, 0.075),
193
+ ncol=len(args.images),
194
+ fontsize="medium",
195
+ frameon=False,
196
+ )
197
+
198
+ # Save figure
199
+ output_fig = output_dir / "labels.png"
200
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
figures/metrics.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script plots examples of the images that get the best and worst metrics
3
+ """
4
+ print("Imports...", end="")
5
+ import os
6
+ import sys
7
+ from argparse import ArgumentParser
8
+ from pathlib import Path
9
+
10
+ import matplotlib.patches as mpatches
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import yaml
16
+ from imageio import imread
17
+ from skimage.color import rgba2rgb
18
+ from sklearn.metrics.pairwise import euclidean_distances
19
+
20
+ sys.path.append("../")
21
+
22
+ from climategan.data import encode_mask_label
23
+ from climategan.eval_metrics import edges_coherence_std_min
24
+ from eval_masker import crop_and_resize
25
+
26
+ # -----------------------
27
+ # ----- Constants -----
28
+ # -----------------------
29
+
30
+ # Metrics
31
+ metrics = ["error", "f05", "edge_coherence"]
32
+
33
+ dict_metrics = {
34
+ "names": {
35
+ "tpr": "TPR, Recall, Sensitivity",
36
+ "tnr": "TNR, Specificity, Selectivity",
37
+ "fpr": "FPR",
38
+ "fpt": "False positives relative to image size",
39
+ "fnr": "FNR, Miss rate",
40
+ "fnt": "False negatives relative to image size",
41
+ "mpr": "May positive rate (MPR)",
42
+ "mnr": "May negative rate (MNR)",
43
+ "accuracy": "Accuracy (ignoring may)",
44
+ "error": "Error",
45
+ "f05": "F05 score",
46
+ "precision": "Precision",
47
+ "edge_coherence": "Edge coherence",
48
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
49
+ },
50
+ "key_metrics": ["error", "f05", "edge_coherence"],
51
+ }
52
+
53
+
54
+ # Colors
55
+ colorblind_palette = sns.color_palette("colorblind")
56
+ color_cannot = colorblind_palette[1]
57
+ color_must = colorblind_palette[2]
58
+ color_may = colorblind_palette[7]
59
+ color_pred = colorblind_palette[4]
60
+
61
+ icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
62
+ color_tp = icefire[0]
63
+ color_tn = icefire[1]
64
+ color_fp = icefire[4]
65
+ color_fn = icefire[3]
66
+
67
+
68
+ def parsed_args():
69
+ """
70
+ Parse and returns command-line args
71
+
72
+ Returns:
73
+ argparse.Namespace: the parsed arguments
74
+ """
75
+ parser = ArgumentParser()
76
+ parser.add_argument(
77
+ "--input_csv",
78
+ default="ablations_metrics_20210311.csv",
79
+ type=str,
80
+ help="CSV containing the results of the ablation study",
81
+ )
82
+ parser.add_argument(
83
+ "--output_dir",
84
+ default=None,
85
+ type=str,
86
+ help="Output directory",
87
+ )
88
+ parser.add_argument(
89
+ "--models_log_path",
90
+ default=None,
91
+ type=str,
92
+ help="Path containing the log files of the models",
93
+ )
94
+ parser.add_argument(
95
+ "--masker_test_set_dir",
96
+ default=None,
97
+ type=str,
98
+ help="Directory containing the test images",
99
+ )
100
+ parser.add_argument(
101
+ "--best_model",
102
+ default="dada, msd_spade, pseudo",
103
+ type=str,
104
+ help="The string identifier of the best model",
105
+ )
106
+ parser.add_argument(
107
+ "--dpi",
108
+ default=200,
109
+ type=int,
110
+ help="DPI for the output images",
111
+ )
112
+ parser.add_argument(
113
+ "--alpha",
114
+ default=0.5,
115
+ type=float,
116
+ help="Transparency of labels shade",
117
+ )
118
+ parser.add_argument(
119
+ "--percentile",
120
+ default=0.05,
121
+ type=float,
122
+ help="Transparency of labels shade",
123
+ )
124
+ parser.add_argument(
125
+ "--seed",
126
+ default=None,
127
+ type=int,
128
+ help="Bootstrap random seed, for reproducibility",
129
+ )
130
+ parser.add_argument(
131
+ "--no_images",
132
+ action="store_true",
133
+ default=False,
134
+ help="Do not generate images",
135
+ )
136
+
137
+ return parser.parse_args()
138
+
139
+
140
+ def map_color(arr, input_color, output_color, rtol=1e-09):
141
+ """
142
+ Maps one color to another
143
+ """
144
+ input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
145
+ output = arr.copy()
146
+ output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
147
+ return output
148
+
149
+
150
+ def plot_labels(ax, img, label, img_id, do_legend):
151
+ label_colmap = label.astype(float)
152
+ label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
153
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
154
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
155
+
156
+ ax.imshow(img)
157
+ ax.imshow(label_colmap, alpha=0.5)
158
+ ax.axis("off")
159
+
160
+ # Annotation
161
+ ax.annotate(
162
+ xy=(0.05, 0.95),
163
+ xycoords="axes fraction",
164
+ xytext=(0.05, 0.95),
165
+ textcoords="axes fraction",
166
+ text=img_id,
167
+ fontsize="x-large",
168
+ verticalalignment="top",
169
+ color="white",
170
+ )
171
+
172
+ # Legend
173
+ if do_legend:
174
+ handles = []
175
+ lw = 1.0
176
+ handles.append(
177
+ mpatches.Patch(facecolor=color_must, label="must", linewidth=lw, alpha=0.66)
178
+ )
179
+ handles.append(
180
+ mpatches.Patch(facecolor=color_may, label="must", linewidth=lw, alpha=0.66)
181
+ )
182
+ handles.append(
183
+ mpatches.Patch(
184
+ facecolor=color_cannot, label="must", linewidth=lw, alpha=0.66
185
+ )
186
+ )
187
+ labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
188
+ ax.legend(
189
+ handles=handles,
190
+ labels=labels,
191
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
192
+ ncol=3,
193
+ mode="expand",
194
+ fontsize="xx-small",
195
+ frameon=False,
196
+ )
197
+
198
+
199
+ def plot_pred(ax, img, pred, img_id, do_legend):
200
+ pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
201
+
202
+ pred_colmap = pred.astype(float)
203
+ pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
204
+ pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
205
+ pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
206
+
207
+ ax.imshow(img)
208
+ ax.imshow(pred_colmap_ma, alpha=0.5)
209
+ ax.axis("off")
210
+
211
+ # Annotation
212
+ ax.annotate(
213
+ xy=(0.05, 0.95),
214
+ xycoords="axes fraction",
215
+ xytext=(0.05, 0.95),
216
+ textcoords="axes fraction",
217
+ text=img_id,
218
+ fontsize="x-large",
219
+ verticalalignment="top",
220
+ color="white",
221
+ )
222
+
223
+ # Legend
224
+ if do_legend:
225
+ handles = []
226
+ lw = 1.0
227
+ handles.append(
228
+ mpatches.Patch(facecolor=color_pred, label="must", linewidth=lw, alpha=0.66)
229
+ )
230
+ labels = ["Prediction"]
231
+ ax.legend(
232
+ handles=handles,
233
+ labels=labels,
234
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
235
+ ncol=3,
236
+ mode="expand",
237
+ fontsize="xx-small",
238
+ frameon=False,
239
+ )
240
+
241
+
242
+ def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend):
243
+ # FP
244
+ fp_map = imread(
245
+ model_path / "eval-metrics/fp" / "{}_fp.png".format(Path(img_filename).stem)
246
+ )
247
+ fp_map = np.tile(np.expand_dims(fp_map, axis=2), reps=(1, 1, 3))
248
+
249
+ fp_map_colmap = fp_map.astype(float)
250
+ fp_map_colmap = map_color(fp_map_colmap, (1, 1, 1), color_fp)
251
+
252
+ # FN
253
+ fn_map = imread(
254
+ model_path / "eval-metrics/fn" / "{}_fn.png".format(Path(img_filename).stem)
255
+ )
256
+ fn_map = np.tile(np.expand_dims(fn_map, axis=2), reps=(1, 1, 3))
257
+
258
+ fn_map_colmap = fn_map.astype(float)
259
+ fn_map_colmap = map_color(fn_map_colmap, (1, 1, 1), color_fn)
260
+
261
+ # TP
262
+ tp_map = imread(
263
+ model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
264
+ )
265
+ tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
266
+
267
+ tp_map_colmap = tp_map.astype(float)
268
+ tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
269
+
270
+ # TN
271
+ tn_map = imread(
272
+ model_path / "eval-metrics/tn" / "{}_tn.png".format(Path(img_filename).stem)
273
+ )
274
+ tn_map = np.tile(np.expand_dims(tn_map, axis=2), reps=(1, 1, 3))
275
+
276
+ tn_map_colmap = tn_map.astype(float)
277
+ tn_map_colmap = map_color(tn_map_colmap, (1, 1, 1), color_tn)
278
+
279
+ label_colmap = label.astype(float)
280
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
281
+ label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
282
+ label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma
283
+
284
+ # Combine masks
285
+ maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
286
+ maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
287
+ maps_ma = maps_ma.mask * img + maps_ma
288
+
289
+ ax.imshow(img)
290
+ ax.imshow(label_colmap_ma, alpha=0.5)
291
+ ax.imshow(maps_ma, alpha=0.5)
292
+ ax.axis("off")
293
+
294
+ # Annotation
295
+ ax.annotate(
296
+ xy=(0.05, 0.95),
297
+ xycoords="axes fraction",
298
+ xytext=(0.05, 0.95),
299
+ textcoords="axes fraction",
300
+ text=img_id,
301
+ fontsize="x-large",
302
+ verticalalignment="top",
303
+ color="white",
304
+ )
305
+
306
+ # Legend
307
+ if do_legend:
308
+ handles = []
309
+ lw = 1.0
310
+ handles.append(
311
+ mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
312
+ )
313
+ handles.append(
314
+ mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66)
315
+ )
316
+ handles.append(
317
+ mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66)
318
+ )
319
+ handles.append(
320
+ mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66)
321
+ )
322
+ handles.append(
323
+ mpatches.Patch(
324
+ facecolor=color_may, label="May-be-flooded", linewidth=lw, alpha=0.66
325
+ )
326
+ )
327
+ labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
328
+ ax.legend(
329
+ handles=handles,
330
+ labels=labels,
331
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
332
+ ncol=5,
333
+ mode="expand",
334
+ fontsize="xx-small",
335
+ frameon=False,
336
+ )
337
+
338
+
339
def plot_edge_coherence(ax, img, label, pred, img_id, do_legend):
    """Visualize edge coherence: prediction edges vs. must-be-flooded edges."""
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    ec, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    ##################
    # Edge distances #
    ##################

    # Pixel coordinates of both edge sets
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)

    # Pairwise pred->label distances, normalized by image height
    dist_mat = np.divide(
        euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
    )

    # Minimum distance from every prediction edge pixel to the label edge
    min_dist = np.min(dist_mat, axis=1)  # noqa: F841

    #############
    # Make plot #
    #############

    def as_rgb(edge_map):
        # Expand a 2D edge map into a float RGB image of 0/1 values
        return np.tile(
            np.expand_dims(np.asarray(edge_map > 0, dtype=float), axis=2),
            reps=(1, 1, 3),
        )

    pred_ec = as_rgb(pred_ec)
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
    pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred)  # noqa: F841

    label_ec = as_rgb(label_ec)
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
    label_ec_colmap_ma = np.ma.masked_not_equal(  # noqa: F841
        label_ec_colmap, color_must
    )

    # Both edge sets drawn over the image
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Tinted prediction mask
    pred_colmap = map_color(pred.astype(float), (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # Tinted must-be-flooded label
    label_colmap = map_color(label.astype(float), (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # Tinted true-positive map
    # NOTE(review): this reads the global `srs_sel` instead of a filename
    # parameter — it relies on the caller's loop variable; confirm intended.
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = map_color(tp_map.astype(float), (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Combine pred / label / TP tints, keeping the edge image underneath
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")

    # Draw a white line from every 100th pred edge pixel (sorted by x)
    # to its closest label edge pixel
    idx_sort_x = np.argsort(pred_ec_coord[:, 1])
    offset = 100
    for idx in range(offset, pred_ec_coord.shape[0], offset):
        y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
        argmin = np.argmin(dist_mat[idx_sort_x[idx]])
        y1, x1 = label_ec_coord[argmin, :]
        ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    # Image identifier in the top-left corner
    ax.annotate(
        xy=(0.05, 0.95),
        xycoords="axes fraction",
        xytext=(0.05, 0.95),
        textcoords="axes fraction",
        text=img_id,
        fontsize="x-large",
        verticalalignment="top",
        color="white",
    )
    # Legend
    if do_legend:
        lw = 1.0
        handles = [
            mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66),
            mpatches.Patch(
                facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66
            ),
            mpatches.Patch(
                facecolor=color_must, label="Must-be-flooded", linewidth=lw, alpha=0.66
            ),
        ]
        ax.legend(
            handles=handles,
            labels=["TP", "Prediction", "Must-be-flooded"],
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
            ncol=3,
            mode="expand",
            fontsize="xx-small",
            frameon=False,
        )
469
+
470
+
471
def plot_images_metric(axes, metric, img_filename, img_id, do_legend):
    """Fill a row of three axes: labels, prediction, and metric visualization."""
    stem = Path(img_filename).stem

    # Load the original image, its annotation, and the model's prediction
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(stem)
    img, label = crop_and_resize(img_path, label_path)
    # Normalize to float RGB (drop the alpha channel when present)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
    pred = imread(model_path / "eval-metrics/pred" / "{}_pred.png".format(stem))

    # Column 0: labels, column 1: prediction
    plot_labels(axes[0], img, label, img_id, do_legend)
    plot_pred(axes[1], img, pred, img_id, do_legend)

    # Column 2: metric-specific visualization
    if metric in ["error", "f05"]:
        plot_correct_incorrect(axes[2], img_filename, img, label, img_id, do_legend)
    elif metric == "edge_coherence":
        plot_edge_coherence(axes[2], img, label, pred, img_id, do_legend)
    else:
        raise ValueError
496
+
497
+
498
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """Scatter two metrics against each other and annotate the selected images."""
    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)

    # Human-readable axis labels
    ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")

    # Lighter frame
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
512
+
513
+
514
def scatterplot_metrics(ax, df, dict_images):
    """Scatter error vs. F05, colored by edge coherence, with image annotations."""
    sns.scatterplot(data=df, x="error", y="f05", hue="edge_coherence", ax=ax)

    # Human-readable axis labels
    ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
    ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")

    annotate_scatterplot(ax, dict_images, "error", "f05")

    # Lighter frame
    sns.despine(ax=ax, left=True, bottom=True)

    # Clamp the error axis at 0 and the F05 axis at 1
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])
534
+
535
+
536
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """Annotate each selected image's (x, y) metric point with its letter id.

    The text is shifted by `offset` (as a fraction of the axis span) towards
    the plot center so annotations stay inside the axes.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    x_len = xlim[1] - xlim[0]
    y_len = ylim[1] - ylim[0]
    # Midpoints: points beyond these get their text offset in the other direction
    x_th = xlim[1] - x_len / 2.0
    y_th = ylim[1] - y_len / 2.0
    for text, d in dict_images.items():
        x, y = d[x_metric], d[y_metric]
        x_text = x + x_len * offset if x < x_th else x - x_len * offset
        y_text = y + y_len * offset if y < y_th else y - y_len * offset
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(x_text, y_text),
            textcoords="data",
            text=text,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )
558
+
559
+
560
+ if __name__ == "__main__":
561
+ # -----------------------------
562
+ # ----- Parse arguments -----
563
+ # -----------------------------
564
+ args = parsed_args()
565
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
566
+
567
+ # Determine output dir
568
+ if args.output_dir is None:
569
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
570
+ else:
571
+ output_dir = Path(args.output_dir)
572
+ if not output_dir.exists():
573
+ output_dir.mkdir(parents=True, exist_ok=False)
574
+
575
+ # Store args
576
+ output_yml = output_dir / "labels.yml"
577
+ with open(output_yml, "w") as f:
578
+ yaml.dump(vars(args), f)
579
+
580
+ # Data dirs
581
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
582
+ labels_path = Path(args.masker_test_set_dir) / "labels"
583
+
584
+ # Read CSV
585
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
586
+
587
+ # Select best model
588
+ df = df.loc[df.model_feats == args.best_model]
589
+ v_key, model_dir = df.model.unique()[0].split("/")
590
+ model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
591
+
592
+ # Set up plot
593
+ sns.reset_orig()
594
+ sns.set(style="whitegrid")
595
+ plt.rcParams.update({"font.family": "serif"})
596
+ plt.rcParams.update(
597
+ {
598
+ "font.serif": [
599
+ "Computer Modern Roman",
600
+ "Times New Roman",
601
+ "Utopia",
602
+ "New Century Schoolbook",
603
+ "Century Schoolbook L",
604
+ "ITC Bookman",
605
+ "Bookman",
606
+ "Times",
607
+ "Palatino",
608
+ "Charter",
609
+ "serif" "Bitstream Vera Serif",
610
+ "DejaVu Serif",
611
+ ]
612
+ }
613
+ )
614
+
615
+ if args.seed:
616
+ np.random.seed(args.seed)
617
+ img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
618
+ dict_images = {}
619
+ idx = 0
620
+ for metric in metrics:
621
+
622
+ fig, axes = plt.subplots(nrows=2, ncols=3, dpi=200, figsize=(18, 12))
623
+
624
+ # Select best
625
+ if metric == "error":
626
+ ascending = True
627
+ else:
628
+ ascending = False
629
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
630
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
631
+ img_id = img_ids[idx]
632
+ dict_images.update({img_id: srs_sel})
633
+
634
+ # Read images
635
+ img_filename = srs_sel.filename
636
+
637
+ if not args.no_images:
638
+ axes_row = axes[0, :]
639
+ plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=True)
640
+
641
+ idx += 1
642
+
643
+ # Select worst
644
+ if metric == "error":
645
+ ascending = False
646
+ else:
647
+ ascending = True
648
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
649
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
650
+ img_id = img_ids[idx]
651
+ dict_images.update({img_id: srs_sel})
652
+
653
+ # Read images
654
+ img_filename = srs_sel.filename
655
+
656
+ if not args.no_images:
657
+ axes_row = axes[1, :]
658
+ plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=False)
659
+
660
+ idx += 1
661
+
662
+ # Save figure
663
+ output_fig = output_dir / "{}.png".format(metric)
664
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
665
+
666
+ fig = plt.figure(dpi=200)
667
+ scatterplot_metrics(fig.gca(), df, dict_images)
668
+
669
+ # fig, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
670
+ #
671
+ # scatterplot_metrics_pair(axes[0], df, 'error', 'f05', dict_images)
672
+ # scatterplot_metrics_pair(axes[1], df, 'error', 'edge_coherence', dict_images)
673
+ # scatterplot_metrics_pair(axes[2], df, 'f05', 'edge_coherence', dict_images)
674
+ #
675
+ output_fig = output_dir / "scatterplots.png"
676
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
figures/metrics_onefig.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This scripts plots examples of the images that get best and worse metrics
3
+ """
4
+ print("Imports...", end="")
5
+ import os
6
+ import sys
7
+ from argparse import ArgumentParser
8
+ from pathlib import Path
9
+
10
+ import matplotlib.patches as mpatches
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import pandas as pd
14
+ import seaborn as sns
15
+ import yaml
16
+ from imageio import imread
17
+ from matplotlib.gridspec import GridSpec
18
+ from skimage.color import rgba2rgb
19
+ from sklearn.metrics.pairwise import euclidean_distances
20
+
21
+ sys.path.append("../")
22
+
23
+ from climategan.data import encode_mask_label
24
+ from climategan.eval_metrics import edges_coherence_std_min
25
+ from eval_masker import crop_and_resize
26
+
27
# -----------------------
# -----  Constants  -----
# -----------------------

# Metrics visualized by this script
metrics = ["error", "f05", "edge_coherence"]

# Human-readable names for every metric, plus the subset used in the figures
dict_metrics = {
    "names": {
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    "key_metrics": ["error", "f05", "edge_coherence"],
}


# Colors
# Palette for the cannot/must/may label categories and the prediction overlay
colorblind_palette = sns.color_palette("colorblind")
color_cannot = colorblind_palette[1]
color_must = colorblind_palette[2]
color_may = colorblind_palette[7]
color_pred = colorblind_palette[4]

# Palette for the TP / TN / FP / FN overlays
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp = icefire[0]
color_tn = icefire[1]
color_fp = icefire[4]
color_fn = icefire[3]
67
+
68
+
69
def parsed_args():
    """
    Parse and returns command-line args

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        # FIX: help text was copy-pasted from --alpha
        help="Fraction of the metric ranking to sample best/worst images from",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )

    return parser.parse_args()
139
+
140
+
141
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of `arr` in which every pixel whose RGB value is (numerically
    close to) `input_color` is replaced by `output_color`. `arr` is unchanged.
    """
    out = arr.copy()
    # A pixel matches only when all of its channels are close to the target
    target = np.tile(input_color, (arr.shape[:2] + (1,)))
    match = np.all(np.isclose(arr, target, rtol=rtol), axis=2)
    out[match] = output_color
    return out
149
+
150
+
151
def plot_labels(ax, img, label, img_id, n_, add_title, do_legend):
    """Show the image with its cannot/must/may annotation overlay and row id."""
    # Recolor the raw annotation (red / blue / black) with the palette colors
    overlay = label.astype(float)
    overlay = map_color(overlay, (255, 0, 0), color_cannot)
    overlay = map_color(overlay, (0, 0, 255), color_must)
    overlay = map_color(overlay, (0, 0, 0), color_may)

    ax.imshow(img)
    ax.imshow(overlay, alpha=0.5)
    ax.axis("off")

    # Rows 1, 3, 5 get a green id, the others red
    color_ = "green" if n_ in [1, 3, 5] else "red"

    ax.text(
        -0.15,
        0.5,
        img_id,
        color=color_,
        fontweight="roman",
        fontsize="x-large",
        horizontalalignment="left",
        verticalalignment="center",
        transform=ax.transAxes,
    )

    if add_title:
        ax.set_title("Labels", rotation=0, fontsize="x-large")
180
+
181
+
182
def plot_pred(ax, img, pred, img_id, add_title, do_legend):
    """Show the image with the binary flood prediction tinted on top."""
    rgb_pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    tinted = map_color(rgb_pred.astype(float), (1, 1, 1), color_pred)
    tinted_ma = np.ma.masked_not_equal(tinted, color_pred)
    # Outside the predicted region, fall back to the original image
    tinted_ma = tinted_ma.mask * img + tinted_ma

    ax.imshow(img)
    ax.imshow(tinted_ma, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Prediction", rotation=0, fontsize="x-large")
196
+
197
+
198
def plot_correct_incorrect(
    ax, img_filename, img, metric, label, img_id, n_, add_title, do_legend
):
    """Overlay the TP/TN/FP/FN maps and the may-be-flooded region on the image."""
    stem = Path(img_filename).stem

    # Load each binary outcome map from disk and tint it with its own color
    tinted = {}
    for kind, color in (
        ("fp", color_fp),
        ("fn", color_fn),
        ("tp", color_tp),
        ("tn", color_tn),
    ):
        outcome = imread(
            model_path / "eval-metrics/{}".format(kind) / "{}_{}.png".format(stem, kind)
        )
        outcome = np.tile(np.expand_dims(outcome, axis=2), reps=(1, 1, 3))
        tinted[kind] = map_color(outcome.astype(float), (1, 1, 1), color)

    # "May-be-flooded" region from the annotation, image elsewhere
    label_colmap = map_color(label.astype(float), (0, 0, 0), color_may)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
    label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma

    # Union of the four tinted maps; show the image where every map is zero
    maps = tinted["fp"] + tinted["fn"] + tinted["tp"] + tinted["tn"]
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(label_colmap_ma, alpha=0.5)
    ax.imshow(maps_ma, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
254
+
255
+
256
def plot_edge_coherence(ax, img, metric, label, pred, img_id, n_, add_title, do_legend):
    """Visualize edge coherence: prediction edges vs. must-be-flooded edges."""
    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    ec, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    ##################
    # Edge distances #
    ##################

    # Pixel coordinates of both edge sets
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)

    # Pairwise pred->label distances, normalized by image height
    dist_mat = np.divide(
        euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
    )

    # Minimum distance from every prediction edge pixel to the label edge
    min_dist = np.min(dist_mat, axis=1)  # noqa: F841

    #############
    # Make plot #
    #############

    def as_rgb(edge_map):
        # Expand a 2D edge map into a float RGB image of 0/1 values
        return np.tile(
            np.expand_dims(np.asarray(edge_map > 0, dtype=float), axis=2),
            reps=(1, 1, 3),
        )

    pred_ec = as_rgb(pred_ec)
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
    pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred)  # noqa: F841

    label_ec = as_rgb(label_ec)
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
    label_ec_colmap_ma = np.ma.masked_not_equal(  # noqa: F841
        label_ec_colmap, color_must
    )

    # Both edge sets drawn over the image
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Tinted prediction mask
    pred_colmap = map_color(pred.astype(float), (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # Tinted must-be-flooded label
    label_colmap = map_color(label.astype(float), (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # Tinted true-positive map
    # NOTE(review): this reads the global `srs_sel` instead of a filename
    # parameter — it relies on the caller's loop variable; confirm intended.
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = map_color(tp_map.astype(float), (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Combine pred / label / TP tints, keeping the edge image underneath
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")

    # Draw a white line from every 100th pred edge pixel (sorted by x)
    # to its closest label edge pixel
    idx_sort_x = np.argsort(pred_ec_coord[:, 1])
    offset = 100
    for idx in range(offset, pred_ec_coord.shape[0], offset):
        y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
        argmin = np.argmin(dist_mat[idx_sort_x[idx]])
        y1, x1 = label_ec_coord[argmin, :]
        ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
352
+
353
+
354
def plot_images_metric(
    axes, metric, img_filename, img_id, n_, srs_sel, add_title, do_legend
):
    """Fill one row: labels, prediction, metric visualization, and legend panel.

    `srs_sel` is the pandas Series holding this image's metric values; its
    `error`, `f05` and `edge_coherence` fields are displayed in the legend.
    """
    stem = Path(img_filename).stem

    # Read the original image, its annotation, and the model's prediction
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(stem)
    img, label = crop_and_resize(img_path, label_path)
    # Normalize to float RGB (drop the alpha channel when present)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0

    pred = imread(model_path / "eval-metrics/pred" / "{}_pred.png".format(stem))

    # Label
    plot_labels(axes[0], img, label, img_id, n_, add_title, do_legend)

    # Prediction
    plot_pred(axes[1], img, pred, img_id, add_title, do_legend)

    lw = 1.0
    # Correct / incorrect
    if metric in ["error", "f05"]:
        plot_correct_incorrect(
            axes[2],
            img_filename,
            img,
            metric,
            label,
            img_id,
            n_,
            add_title,
            do_legend=False,
        )
        handles = [
            mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66),
            mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66),
            mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66),
            mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66),
            mpatches.Patch(
                facecolor=color_may,
                label="May-be-flooded",
                linewidth=lw,
                alpha=0.66,
            ),
        ]
        labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
        # Rows 1, 3, 5 hold the best-scoring examples
        if metric == "error":
            title = "Low error rate" if n_ in [1, 3, 5] else "High error rate"
        else:
            title = "High F05 score" if n_ in [1, 3, 5] else "Low F05 score"
    # Edge coherence
    elif metric == "edge_coherence":
        plot_edge_coherence(
            axes[2], img, metric, label, pred, img_id, n_, add_title, do_legend=False
        )
        handles = [
            mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66),
            mpatches.Patch(
                facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66
            ),
            mpatches.Patch(
                facecolor=color_must,
                label="Must-be-flooded",
                linewidth=lw,
                alpha=0.66,
            ),
        ]
        labels = ["TP", "Prediction", "Must-be-flooded"]
        title = "High edge coherence" if n_ in [1, 3, 5] else "Low edge coherence"

    else:
        raise ValueError("Unknown metric: {}".format(metric))

    # FIX: displayed label read "FO5" (letter O) instead of "F05"
    labels_values_title = "Error: {:.4f} \nF05: {:.4f} \nEdge coherence: {:.4f}".format(
        srs_sel.error, srs_sel.f05, srs_sel.edge_coherence
    )

    plot_legend(axes[3], img, handles, labels, labels_values_title, title)
455
+
456
+
457
def plot_legend(ax, img, handles, labels, labels_values_title, title):
    """Draw a white panel holding the color legend and the metric values."""
    # Blank white canvas with the same shape as the plotted images
    canvas = np.full_like(img, 255, dtype=np.uint8)
    ax.imshow(canvas)
    ax.axis("off")

    # Color legend at the top
    leg1 = ax.legend(
        handles=handles,
        labels=labels,
        title=title,
        title_fontsize="medium",
        labelspacing=0.6,
        loc="upper left",
        fontsize="x-small",
        frameon=False,
    )
    leg1._legend_box.align = "left"

    # Metric values at the bottom; this second call replaces leg1 on the axes
    leg2 = ax.legend(
        title=labels_values_title,
        title_fontsize="small",
        loc="lower left",
        frameon=False,
    )
    leg2._legend_box.align = "left"

    # Re-attach the first legend so both are shown
    ax.add_artist(leg1)
484
+
485
+
486
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """Scatter two metrics against each other and annotate the selected images."""
    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)

    # Human-readable axis labels
    ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")

    # Lighter frame
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
500
+
501
+
502
def scatterplot_metrics(ax, df, df_all, dict_images, plot_all=False):
    """Scatter error vs. F05 (hue = edge coherence) for the best model.

    If `plot_all`, also show ground-truth, InstaGAN, and the remaining
    ablation models' points with distinct markers.
    """
    # Other
    if plot_all:
        # Ground-truth rows
        sns.scatterplot(
            data=df_all.loc[df_all.ground == True],  # noqa: E712
            x="error", y="f05", hue="edge_coherence", ax=ax,
            marker='+', alpha=0.25)
        # InstaGAN rows
        sns.scatterplot(
            data=df_all.loc[df_all.instagan == True],  # noqa: E712
            x="error", y="f05", hue="edge_coherence", ax=ax,
            marker='x', alpha=0.25)
        # Remaining ablation models, excluding the best one.
        # FIX: the filter tested `instagan == False` twice; the first clause
        # was meant to exclude ground-truth rows as well.
        sns.scatterplot(
            data=df_all.loc[(df_all.ground == False) & (df_all.instagan == False) &  # noqa: E712
                            (df_all.model_feats != args.best_model)],
            x="error", y="f05", hue="edge_coherence", ax=ax,
            marker='s', alpha=0.25)

    # Best model
    cmap_ = sns.cubehelix_palette(as_cmap=True)
    sns.scatterplot(
        data=df, x="error", y="f05", hue="edge_coherence", ax=ax, palette=cmap_
    )

    norm = plt.Normalize(df["edge_coherence"].min(), df["edge_coherence"].max())
    sm = plt.cm.ScalarMappable(cmap=cmap_, norm=norm)
    sm.set_array([])

    # Remove the categorical legend and add a continuous colorbar instead
    ax.get_legend().remove()
    ax_cbar = ax.figure.colorbar(sm)
    ax_cbar.set_label("Edge coherence", labelpad=8)

    # Set X-label
    ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")

    # Set Y-label
    ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")

    annotate_scatterplot(ax, dict_images, "error", "f05")

    # Change spines
    sns.despine(ax=ax, left=True, bottom=True)

    # Clamp the error axis at 0 and the F05 axis at 1
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])
551
+
552
+
553
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """Annotate selected images on the scatterplot.

    Images B, D and F get individual annotations; A, C and E share a single
    annotation placed at the centroid of their points.
    """
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    x_len = xlim[1] - xlim[0]
    y_len = ylim[1] - ylim[0]
    # Midpoints: points beyond these get their text offset in the other direction
    x_th = xlim[1] - x_len / 2.0
    y_th = ylim[1] - y_len / 2.0

    def _annotate(x, y, x_text, y_text, text):
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(x_text, y_text),
            textcoords="data",
            text=text,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )

    for text, d in dict_images.items():
        if text in ["B", "D", "F"]:
            x = d[x_metric]
            y = d[y_metric]

            x_text = x + x_len * offset if x < x_th else x - x_len * offset
            y_text = y + y_len * offset if y < y_th else y - y_len * offset

            _annotate(x, y, x_text, y_text, text)
        elif text == "A":
            # Centroid of the A, C and E points
            x = sum(dict_images[k][x_metric] for k in ("A", "C", "E")) / 3
            y = sum(dict_images[k][y_metric] for k in ("A", "C", "E")) / 3

            x_text = x + x_len * 2 * offset if x < x_th else x - x_len * 2 * offset
            y_text = (
                y + y_len * 0.45 * offset if y < y_th else y - y_len * 0.45 * offset
            )

            _annotate(x, y, x_text, y_text, "A, C, E")
605
+
606
+
607
+ if __name__ == "__main__":
608
+ # -----------------------------
609
+ # ----- Parse arguments -----
610
+ # -----------------------------
611
+ args = parsed_args()
612
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
613
+
614
+ # Determine output dir
615
+ if args.output_dir is None:
616
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
617
+ else:
618
+ output_dir = Path(args.output_dir)
619
+ if not output_dir.exists():
620
+ output_dir.mkdir(parents=True, exist_ok=False)
621
+
622
+ # Store args
623
+ output_yml = output_dir / "labels.yml"
624
+ with open(output_yml, "w") as f:
625
+ yaml.dump(vars(args), f)
626
+
627
+ # Data dirs
628
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
629
+ labels_path = Path(args.masker_test_set_dir) / "labels"
630
+
631
+ # Read CSV
632
+ df_all = pd.read_csv(args.input_csv, index_col="model_img_idx")
633
+
634
+ # Select best model
635
+ df = df_all.loc[df_all.model_feats == args.best_model]
636
+ v_key, model_dir = df.model.unique()[0].split("/")
637
+ model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
638
+
639
+ # Set up plot
640
+ sns.reset_orig()
641
+ sns.set(style="whitegrid")
642
+ plt.rcParams.update({"font.family": "serif"})
643
+ plt.rcParams.update(
644
+ {
645
+ "font.serif": [
646
+ "Computer Modern Roman",
647
+ "Times New Roman",
648
+ "Utopia",
649
+ "New Century Schoolbook",
650
+ "Century Schoolbook L",
651
+ "ITC Bookman",
652
+ "Bookman",
653
+ "Times",
654
+ "Palatino",
655
+ "Charter",
656
+ "serif" "Bitstream Vera Serif",
657
+ "DejaVu Serif",
658
+ ]
659
+ }
660
+ )
661
+
662
+ if args.seed:
663
+ np.random.seed(args.seed)
664
+ img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
665
+ dict_images = {}
666
+ idx = 0
667
+
668
+ # Define grid of subplots
669
+ grid_vmargin = 0.03 # Extent of the vertical margin between metric grids
670
+ ax_hspace = 0.04 # Extent of the vertical space between axes of same grid
671
+ ax_wspace = 0.05 # Extent of the horizontal space between axes of same grid
672
+ n_grids = len(metrics)
673
+ n_cols = 4
674
+ n_rows = 2
675
+ h_grid = (1.0 / n_grids) - ((n_grids - 1) * grid_vmargin) / n_grids
676
+
677
+ fig1 = plt.figure(dpi=200, figsize=(11, 13))
678
+
679
+ n_ = 0
680
+ add_title = False
681
+ for metric_id, metric in enumerate(metrics):
682
+
683
+ # Create grid
684
+ top_grid = 1.0 - metric_id * h_grid - metric_id * grid_vmargin
685
+ bottom_grid = top_grid - h_grid
686
+ gridspec = GridSpec(
687
+ n_rows,
688
+ n_cols,
689
+ wspace=ax_wspace,
690
+ hspace=ax_hspace,
691
+ bottom=bottom_grid,
692
+ top=top_grid,
693
+ )
694
+
695
+ # Select best
696
+ if metric == "error":
697
+ ascending = True
698
+ else:
699
+ ascending = False
700
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
701
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
702
+ img_id = img_ids[idx]
703
+ dict_images.update({img_id: srs_sel})
704
+ # Read images
705
+ img_filename = srs_sel.filename
706
+
707
+ axes_row = [fig1.add_subplot(gridspec[0, c]) for c in range(n_cols)]
708
+ if not args.no_images:
709
+ n_ += 1
710
+ if metric_id == 0:
711
+ add_title = True
712
+ plot_images_metric(
713
+ axes_row,
714
+ metric,
715
+ img_filename,
716
+ img_id,
717
+ n_,
718
+ srs_sel,
719
+ add_title=add_title,
720
+ do_legend=False,
721
+ )
722
+ add_title = False
723
+
724
+ idx += 1
725
+ print("1 more row done.")
726
+ # Select worst
727
+ if metric == "error":
728
+ ascending = False
729
+ else:
730
+ ascending = True
731
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
732
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
733
+ img_id = img_ids[idx]
734
+ dict_images.update({img_id: srs_sel})
735
+ # Read images
736
+ img_filename = srs_sel.filename
737
+
738
+ axes_row = [fig1.add_subplot(gridspec[1, c]) for c in range(n_cols)]
739
+ if not args.no_images:
740
+ n_ += 1
741
+ plot_images_metric(
742
+ axes_row,
743
+ metric,
744
+ img_filename,
745
+ img_id,
746
+ n_,
747
+ srs_sel,
748
+ add_title=add_title,
749
+ do_legend=False,
750
+ )
751
+
752
+ idx += 1
753
+ print("1 more row done.")
754
+
755
+ output_fig = output_dir / "all_metrics.png"
756
+
757
+ fig1.tight_layout() # (pad=1.5) #
758
+ fig1.savefig(output_fig, dpi=fig1.dpi, bbox_inches="tight")
759
+
760
+ # Scatter plot
761
+ fig2 = plt.figure(dpi=200)
762
+
763
+ scatterplot_metrics(fig2.gca(), df, df_all, dict_images)
764
+
765
+ # fig2, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
766
+ #
767
+ # scatterplot_metrics_pair(axes[0], df, "error", "f05", dict_images)
768
+ # scatterplot_metrics_pair(axes[1], df, "error", "edge_coherence", dict_images)
769
+ # scatterplot_metrics_pair(axes[2], df, "f05", "edge_coherence", dict_images)
770
+
771
+ output_fig = output_dir / "scatterplots.png"
772
+ fig2.savefig(output_fig, dpi=fig2.dpi, bbox_inches="tight")
inferences.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
2
+ # thank you @NimaBoscarino
3
+
4
+ import torch
5
+ from skimage.color import rgba2rgb
6
+ from skimage.transform import resize
7
+ import numpy as np
8
+
9
+ from climategan.trainer import Trainer
10
+
11
+
12
def uint8(array):
    """
    Cast an array to np.uint8 — a pure dtype change, with no rescaling,
    clipping or any other transformation of the values.
    Args:
        array (np.array): array to convert
    Returns:
        np.array(np.uint8): converted array
    """
    return array.astype(np.uint8)
21
+
22
+
23
def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest
    dimension is `to`, then crops the resized image at its center so that the
    output is `to x to`, without aspect-ratio distortion.
    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): side length of the square output. Defaults to 640.
    Returns:
        np.array: [0, 1] np.float32 image
    """
    # resize keeping aspect ratio: smallest dim becomes `to`
    height, width = img.shape[:2]
    if height < width:
        new_size = (to, int(to * width / height))
    else:
        new_size = (int(to * height / width), to)

    resized = uint8(resize(img, new_size, preserve_range=True, anti_aliasing=True))

    # center crop to a `to x to` square
    new_h, new_w = resized.shape[:2]
    y0 = (new_h - to) // 2
    x0 = (new_w - to) // 2
    cropped = resized[y0 : y0 + to, x0 : x0 + to, :]

    return cropped / 255.0
52
+
53
+
54
def to_m1_p1(img):
    """
    Rescale a [0, 1] image to the [-1, +1] range.
    Args:
        img (np.array): float32 numpy array of an image in [0, 1]
    Raises:
        ValueError: if the image values are not within [0, 1]
    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    lo = img.min()
    hi = img.max()
    if lo < 0 or hi > 1:
        raise ValueError(f"Data range mismatch for image: ({lo}, {hi})")
    return (img.astype(np.float32) - 0.5) * 2
68
+
69
+
70
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Minimal inference wrapper around a ClimateGAN Trainer for the
    Hugging Face Space: loads a checkpoint once, then turns an input
    photo into flood / wildfire / smog renderings.
    """

    def __init__(self, model_path) -> None:
        # inference only: gradients are never needed
        torch.set_grad_enabled(False)
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )

    # Does all three inferences at the moment.
    def inference(self, orig_image):
        prepared = self._preprocess_image(orig_image)

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        events = self.trainer.infer_all(
            prepared,
            numpy=True,
            bin_value=0.5,
        )

        flood = events["flood"].squeeze()
        wildfire = events["wildfire"].squeeze()
        smog = events["smog"].squeeze()
        return flood, wildfire, smog

    def _preprocess_image(self, img):
        # rgba to rgb
        rgb = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        rgb = resize_and_crop(rgb, self.target_size)

        # resize() produces [0, 1] images, rescale to [-1, 1]
        return to_m1_p1(rgb)
requirements-3.8.2.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict==2.4.0
2
+ APScheduler==3.7.0
3
+ attrs==21.2.0
4
+ backcall==0.2.0
5
+ Brotli==1.0.9
6
+ certifi==2021.5.30
7
+ charset-normalizer==2.0.4
8
+ click==8.0.1
9
+ codecarbon==1.2.0
10
+ comet-ml==3.15.3
11
+ configobj==5.0.6
12
+ cycler==0.10.0
13
+ dash==2.0.0
14
+ dash-bootstrap-components==0.13.0
15
+ dash-core-components==2.0.0
16
+ dash-html-components==2.0.0
17
+ dash-table==5.0.0
18
+ dataclasses==0.6
19
+ decorator==5.0.9
20
+ dulwich==0.20.25
21
+ everett==2.0.1
22
+ filelock==3.0.12
23
+ fire==0.4.0
24
+ Flask==2.0.1
25
+ Flask-Compress==1.10.1
26
+ future==0.18.2
27
+ gdown==3.13.0
28
+ hydra-core==0.11.3
29
+ idna==3.2
30
+ imageio==2.9.0
31
+ ipython==7.27.0
32
+ itsdangerous==2.0.1
33
+ jedi==0.18.0
34
+ Jinja2==3.0.1
35
+ joblib==1.0.1
36
+ jsonschema==3.2.0
37
+ kiwisolver==1.3.2
38
+ kornia==0.5.10
39
+ MarkupSafe==2.0.1
40
+ matplotlib==3.4.3
41
+ matplotlib-inline==0.1.2
42
+ networkx==2.6.2
43
+ numpy==1.21.2
44
+ nvidia-ml-py3==7.352.0
45
+ omegaconf==1.4.1
46
+ opencv-python==4.5.3.56
47
+ packaging==21.0
48
+ pandas==1.3.2
49
+ parso==0.8.2
50
+ pexpect==4.8.0
51
+ pickleshare==0.7.5
52
+ Pillow==8.3.2
53
+ plotly==5.3.1
54
+ prompt-toolkit==3.0.20
55
+ ptyprocess==0.7.0
56
+ py-cpuinfo==8.0.0
57
+ Pygments==2.10.0
58
+ pynvml==11.0.0
59
+ pyparsing==2.4.7
60
+ pyrsistent==0.18.0
61
+ PySocks==1.7.1
62
+ python-dateutil==2.8.2
63
+ pytorch-ranger==0.1.1
64
+ pytz==2021.1
65
+ PyWavelets==1.1.1
66
+ PyYAML==5.4.1
67
+ requests==2.26.0
68
+ requests-toolbelt==0.9.1
69
+ scikit-image==0.18.3
70
+ scikit-learn==0.24.2
71
+ scipy==1.7.1
72
+ seaborn==0.11.2
73
+ semantic-version==2.8.5
74
+ six==1.16.0
75
+ tenacity==8.0.1
76
+ termcolor==1.1.0
77
+ threadpoolctl==2.2.0
78
+ tifffile==2021.8.30
79
+ torch==1.7.1
80
+ torch-optimizer==0.1.0
81
+ torchvision==0.8.2
82
+ tqdm==4.62.2
83
+ traitlets==5.1.0
84
+ typing-extensions==3.10.0.2
85
+ tzlocal==2.1
86
+ urllib3==1.26.6
87
+ wcwidth==0.2.5
88
+ websocket-client==1.2.1
89
+ Werkzeug==2.0.1
90
+ wrapt==1.12.1
91
+ wurlitzer==3.0.2
requirements-any.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict
2
+ codecarbon
3
+ comet_ml
4
+ hydra-core==0.11.3
5
+ kornia
6
+ omegaconf==1.4.1
7
+ matplotlib
8
+ numpy
9
+ opencv-python
10
+ packaging
11
+ pandas
12
+ PyYAML
13
+ scikit-image
14
+ scikit-learn
15
+ scipy
16
+ seaborn
17
+ torch==1.7.0
18
+ torch-optimizer
19
+ torchvision==0.8.1
20
+ tqdm
sbatch.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import itertools
3
+ import os
4
+ import re
5
+ import subprocess
6
+ import sys
7
+ from collections import defaultdict
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import yaml
12
+
13
+
14
def flatten_conf(conf, to=None, parents=None):
    """
    Flattens a configuration dict: nested dictionaries are flattened
    as key1.key2.key3 = value

    conf.yaml:
    ```yaml
    a: 1
    b:
        c: 2
        d:
            e: 3
        g:
            sample: sequential
            from: [4, 5]
    ```

    Is flattened to

    {
        "a": 1,
        "b.c": 2,
        "b.d.e": 3,
        "b.g": {
            "sample": "sequential",
            "from": [4, 5]
        }
    }

    Does not affect sampling dicts (dicts containing a "sample" key).

    Args:
        conf (dict): the configuration to flatten
        to (dict, optional): the target flattened dict. A fresh dict is created
            when omitted — the previous `to={}` mutable default was shared
            across calls, so successive no-arg calls accumulated stale keys.
            Defaults to None.
        parents (list, optional): a final value's list of parents. Defaults
            to None (fresh list, same mutable-default fix as `to`).

    Returns:
        dict: the flattened dict (the same object as `to` when provided),
            so results are no longer lost when the default target is used
    """
    if to is None:
        to = {}
    if parents is None:
        parents = []
    for k, v in conf.items():
        if isinstance(v, dict) and "sample" not in v:
            flatten_conf(v, to, parents + [k])
        else:
            new_k = ".".join(str(p) for p in parents + [k])
            to[new_k] = v
    return to
56
+
57
+
58
def env_to_path(path):
    """Replaces environment-variable mentions in a path with their values.
    E.g. $HOME/clouds -> /home/vsch/clouds

    Args:
        path (str): path potentially containing env variables

    """
    resolved = [
        os.environ[segment.replace("$", "")] if "$" in segment else segment
        for segment in path.split("/")
    ]
    return "/".join(resolved)
74
+
75
+
76
class C:
    """ANSI escape sequences used to colorize and style terminal output."""

    HEADER = "\033[95m"  # bright magenta
    OKBLUE = "\033[94m"  # bright blue
    OKGREEN = "\033[92m"  # bright green
    WARNING = "\033[93m"  # bright yellow
    FAIL = "\033[91m"  # bright red
    ENDC = "\033[0m"  # reset all colors/styles
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    ITALIC = "\33[3m"
    BEIGE = "\33[36m"  # cyan (named "beige" here)
87
+
88
+
89
def escape_path(path):
    """Backslash-escape spaces and parentheses in `path` for shell usage."""
    escaped = str(path)
    for char in (" ", "(", ")"):
        escaped = escaped.replace(char, "\\" + char)
    return escaped
92
+
93
+
94
def warn(*args, **kwargs):
    """Print `args` in the terminal's warning color, forwarding kwargs to print."""
    print(f"{C.WARNING}{' '.join(args)}{C.ENDC}", **kwargs)
96
+
97
+
98
def parse_jobID(command_output):
    """
    get job id from successful sbatch command output like
    `Submitted batch job 599583`

    Args:
        command_output (str): sbatch command's output

    Returns:
        int: the slurm job's ID, or -1 when it cannot be parsed
    """
    # Guard the type BEFORE calling .strip(): the previous version stripped
    # first, so any non-str input crashed with AttributeError instead of
    # falling through to the -1 sentinel the isinstance check intended.
    if isinstance(command_output, str):
        command_output = command_output.strip()
        if "Submitted batch job" in command_output:
            return int(command_output.split()[-1])

    return -1
115
+
116
+
117
def now():
    """Current local timestamp as a string, with the space replaced by '_'."""
    return "_".join(str(datetime.datetime.now()).split(" "))
119
+
120
+
121
def cols():
    """Terminal width in columns, falling back to 50 when it cannot be read."""
    try:
        return os.get_terminal_size().columns
    except Exception:
        return 50
127
+
128
+
129
def print_box(txt):
    """Print `txt` centered inside an ASCII box (an error marker when empty)."""
    if not txt:
        txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
        lt = 7
    else:
        lt = len(txt)
    width = lt + 12
    body = "|" + " " * 5 + txt + " " * 5 + "|"
    horizontal = "-" * width
    spacer = "|" + " " * (width - 2) + "|"
    for row in (horizontal, spacer, body, spacer, horizontal):
        print(row)
144
+
145
+
146
def print_header(idx):
    """Print a full-width banner announcing run number `idx`."""
    bold = C.BOLD
    blue = C.OKBLUE
    end = C.ENDC
    char = "≡"
    width = cols()

    txt = " " * 20 + f"{bold}{blue}Run {idx}{end}" + " " * 20
    # visible length excludes the ANSI escape sequences
    visible = len(txt) - len(bold) - len(blue) - len(end)
    left = int(np.floor((width - visible) / 2))
    right = int(np.ceil((width - visible) / 2))

    print(char * width)
    print(char * left + " " * visible + char * right)
    print(char * left + txt + char * right)
    print(char * left + " " * visible + char * right)
    print(char * width)
165
+
166
+
167
def print_footer():
    """Print a decorative horizontal rule followed by a centered bullet."""
    width = cols()
    char = "﹎"
    print()
    print(char * (width // len(char)))
    print()
    print(" " * (width // 2) + "•" + " " * (width - width // 2 - 1))
    print()
175
+
176
+
177
def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=()):
    """
    Append the current run's template and train-arg values to a per-key summary.

    Args:
        summary (dict | None): maps keys to the list of values seen so far;
            a fresh defaultdict(list) is created when None
        tmp_train_args_dict (dict): train args of the current run
        tmp_template_dict (dict): template values of the current run
        exclude (iterable, optional): keys to ignore. Defaults to () — an
            immutable default replacing the previous mutable `[]` default.

    Returns:
        dict: the updated summary
    """
    exclude = set(exclude)
    if summary is None:
        summary = defaultdict(list)
    for k, v in tmp_template_dict.items():
        if k not in exclude:
            summary[k].append(v)
    for k, v in tmp_train_args_dict.items():
        if k not in exclude:
            # lists are stored by their string representation so that the
            # summary values stay hashable/printable
            if isinstance(v, list):
                v = str(v)
            summary[k].append(v)
    return summary
190
+
191
+
192
def search_summary_table(summary, summary_dir=None):
    """
    Render a markdown table of the hyperparameter values that vary across a
    search, print-ready (split into terminal-width column groups) and
    optionally written to a timestamped .md file.

    Args:
        summary (dict): maps parameter names to the list of their values,
            one value per experiment (as built by `extend_summary`)
        summary_dir (Path, optional): directory to write the markdown file
            into; no file is written when None. Defaults to None.

    Returns:
        tuple(str | None, Path | None): (table string to print, path of the
            written markdown file), or (None, None) when nothing varies
    """
    # filter out constant values
    summary = {k: v for k, v in summary.items() if len(set(v)) > 1}

    # if everything is constant: no summary
    if not summary:
        return None, None

    # find number of searches
    n_searches = len(list(summary.values())[0])

    # print section title
    print(
        "{}{}{}Varying values across {} experiments:{}\n".format(
            C.OKBLUE,
            C.BOLD,
            C.UNDERLINE,
            n_searches,
            C.ENDC,
        )
    )

    # first column holds the Exp. number
    first_col = {
        "len": 8,  # length of a column, to split columns according to terminal width
        "str": ["| Exp. |", "|:----:|"]
        + [
            "| {0:^{1}} |".format(i, 4) for i in range(n_searches)
        ],  # list of values to print
    }

    # print_columns: groups of columns, each group fitting the terminal width
    # file_columns: all columns, written as a single wide table
    print_columns = [[first_col]]
    file_columns = [first_col]
    for k in sorted(summary.keys()):
        v = summary[k]
        col_title = f" {k} |"
        col_blank_line = f":{'-' * len(k)}-|"
        col_values = [
            " {0:{1}} |".format(
                crop_string(
                    str(crop_float(v[idx], min([5, len(k) - 2]))), len(k)
                ),  # crop floats and long strings
                len(k),
            )
            for idx in range(len(v))
        ]

        # create column object
        col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}

        # if adding a new column would overflow the terminal and mess up printing, start
        # new set of columns
        if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
            print_columns.append([first_col])

        # store current column to latest group of columns
        print_columns[-1].append(col)
        file_columns.append(col)

    print_table = ""
    # print each column group individually
    for colgroup in print_columns:
        # print columns line by line
        for i in range(n_searches + 2):
            # get value of column for current line i
            for col in colgroup:
                print_table += col["str"][i]
            # next line for current columns
            print_table += "\n"

        # new lines for new column group
        print_table += "\n"

    # the file gets the full-width table (no terminal-width splitting)
    file_table = ""
    for i in range(n_searches + 2):
        # get value of column for current line i
        for col in file_columns:
            file_table += col["str"][i]
        # next line for current columns
        file_table += "\n"

    summary_path = None
    if summary_dir is not None:
        summary_path = summary_dir / (now() + ".md")
        with summary_path.open("w") as f:
            f.write(file_table.strip())

    return print_table, summary_path
280
+
281
+
282
def clean_arg(v):
    """
    chain cleaning function: resolve env vars, quote, crop floats,
    stringify lists

    Args:
        v (any): arg to pass to train.py

    Returns:
        str: parsed value to string
    """
    v = resolve_env(v)
    v = quote_string(v)
    v = crop_float(v)
    return stringify_list(v)
293
+
294
+
295
def resolve_env(v):
    """
    resolve env variables in string values; anything else is returned
    unchanged, and any resolution failure falls back to the raw value

    Args:
        v (any): arg to pass to train.py

    Returns:
        str: try and resolve an env variable
    """
    if not isinstance(v, str) or "$" not in v:
        return v
    try:
        if "/" in v:
            return env_to_path(v)
        resolved = os.environ.get(v)
        return v if resolved is None else resolved
    except Exception:
        # best effort only: keep the original value on any failure
        return v
317
+
318
+
319
def stringify_list(v):
    """
    Stringify a list (with double quotes) so that it can be passed as an
    argument to train.py's hydra command-line parsing

    Args:
        v (any): value to clean

    Returns:
        any: type of v, str if v was a list
    """
    if isinstance(v, list):
        inner = str(v).replace('"', "'")
        return f'"{inner}"'
    if isinstance(v, str) and v.startswith("[") and v.endswith("]"):
        return f'"{v}"'
    return v
336
+
337
+
338
def quote_string(v):
    """
    Wrap a string in double quotes when it contains a " " or an "="

    Args:
        v (any): value to clean

    Returns:
        any: type of v, quoted if v is a string with " " or "="
    """
    needs_quotes = isinstance(v, str) and (" " in v or "=" in v)
    return f'"{v}"' if needs_quotes else v
352
+
353
+
354
def crop_float(v, k=5):
    """
    If v is a float, crop it to `k` significant digits and return it as a str

    Args:
        v (any): value to crop if float
        k (int, optional): number of significant digits. Defaults to 5.

    Returns:
        any: cropped float as str if v is a float, original v otherwise
    """
    return "{0:.{1}g}".format(v, k) if isinstance(v, float) else v
367
+
368
+
369
def compute_n_search(conf):
    """
    Compute the number of searches to run when n_search is -1 and
    cartesian or sequential sampling is used

    Args:
        conf (dict): experimental configuration

    Returns:
        int: size of the cartesian product or length of longest sequential field

    Raises:
        ValueError: when no field uses 'cartesian' or 'sequential' sampling
    """
    samples = defaultdict(list)
    for value in conf.values():
        if isinstance(value, dict) and "sample" in value:
            samples[value["sample"]].append(value)

    candidates = []

    if "cartesian" in samples:
        product = 1
        for s in samples["cartesian"]:
            product *= len(s["from"])
        candidates.append(product)
    if "sequential" in samples:
        candidates.append(max(len(s["from"]) for s in samples["sequential"]))

    if candidates:
        return max(candidates)

    raise ValueError(
        "Used n_search=-1 without any field being 'cartesian' or 'sequential'"
    )
403
+
404
+
405
def crop_string(s, k=10):
    """Crop `s` to at most `k` characters, ending in ".." when truncated."""
    return s if len(s) <= k else s[: k - 2] + ".."
410
+
411
+
412
def sample_param(sample_dict):
    """sample a value (hyperparameter) from the instruction in the
    sample dict:
    {
        "sample": "range | list",
        "from": [min, max, step] | [v0, v1, v2 etc.]
    }
    if range, as np.arange is used, "from" MUST be a list, but may contain
    only 1 (=min) or 2 (min and max) values, not necessarily 3

    Args:
        sample_dict (dict): instructions to sample a value

    Returns:
        scalar: sampled value (or the "__cartesian__" / "__sequential__"
            markers, resolved later by the caller)
    """
    # non-sampling values pass through untouched
    if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
        return sample_dict

    mode = sample_dict["sample"]

    # cartesian / sequential are resolved later: just validate and mark
    if mode in ("cartesian", "sequential"):
        assert isinstance(
            sample_dict["from"], list
        ), "{}'s `from` field MUST be a list, found {}".format(
            mode, sample_dict["from"]
        )
        return f"__{mode}__"

    if mode == "range":
        return np.random.choice(np.arange(*sample_dict["from"]))

    if mode == "list":
        return np.random.choice(sample_dict["from"])

    if mode == "uniform":
        return np.random.uniform(*sample_dict["from"])

    raise ValueError("Unknown sample type in dict " + str(sample_dict))
457
+
458
+
459
def sample_sequentials(sequential_keys, exp, idx):
    """
    Samples sequentially from the "from" values specified in each key of the
    experimental configuration which have sample == "sequential".
    Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
    over each key (wrapping around with a modulo).

    Args:
        sequential_keys (list): keys to be sampled sequentially
        exp (dict): experimental config
        idx (int): index of the current sample

    Returns:
        dict: sampled config (one value per sequential key)
    """
    return {
        key: exp[key]["from"][idx % len(exp[key]["from"])]
        for key in sequential_keys
    }
479
+
480
+
481
def sample_cartesians(cartesian_keys, exp, idx):
    """
    Returns the `idx`th item (modulo the product size) in the cartesian
    product of all cartesian keys to be sampled.

    Args:
        cartesian_keys (list): keys in the experimental configuration that are
            to be used in the full cartesian product
        exp (dict): experimental configuration
        idx (int): index of the current sample

    Returns:
        dict: sampled point in the cartesian space (with keys = cartesian_keys)
    """
    grids = [exp[key]["from"] for key in cartesian_keys]
    points = list(itertools.product(*grids))
    point = points[idx % len(points)]
    return dict(zip(cartesian_keys, point))
501
+
502
+
503
def resolve(hp_conf, nb):
    """
    Samples `nb` configurations from the parametrization in `hp_conf`:
    values should fit `sample_param(dic)`'s API

    Args:
        hp_conf (dict): experiment's parametrization
        nb (int): number of experiments to sample (-1 to infer it from the
            cartesian/sequential fields)

    Returns:
        list(dict): sampled configurations
    """
    if nb == -1:
        nb = compute_n_search(hp_conf)

    confs = []
    for idx in range(nb):
        conf = {}
        cartesian_keys = []
        sequential_keys = []
        for key, value in hp_conf.items():
            sampled = sample_param(value)
            # cartesian/sequential keys are resolved after the first pass
            if sampled == "__cartesian__":
                cartesian_keys.append(key)
            elif sampled == "__sequential__":
                sequential_keys.append(key)
            else:
                conf[key] = sampled
        if sequential_keys:
            conf.update(sample_sequentials(sequential_keys, hp_conf, idx))
        if cartesian_keys:
            conf.update(sample_cartesians(cartesian_keys, hp_conf, idx))
        confs.append(conf)
    return confs
537
+
538
+
539
def get_template_params(template):
    """
    extract args in template str as {arg}

    Args:
        template (str): sbatch template string

    Returns:
        iterator(str): args required to format the template string
    """
    return (
        match.replace("{", "").replace("}", "")
        for match in re.findall(r"\{.*?\}", template)
    )
553
+
554
+
555
def read_exp_conf(name):
    """
    Read hp search configuration from shared/experiment/ or config/experiment/,
    specified with or without the .yaml extension

    Args:
        name (str): name of the config to find in <dir>/experiment/

    Returns:
        Tuple(Path, dict): file path and loaded (flattened) config dict

    Raises:
        ValueError: no matching config file could be found
    """
    if ".yaml" not in name:
        name += ".yaml"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "experiment" / name
        if path.exists():
            paths.append(path.resolve())

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "experiment" for d in dirs]
        s = "Could not find search config {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.yaml"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant files for search config:\n{}".format(
                # FIX: str.join raises TypeError on Path objects — stringify
                # each element first
                "\n".join(map(str, paths))
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        conf = yaml.safe_load(f)

    flat_conf = {}
    flatten_conf(conf, to=flat_conf)

    return (paths[-1], flat_conf)
599
+
600
+
601
def read_template(name):
    """
    Read template from shared/template/ or config/template/, specified with
    or without the .sh extension

    Args:
        name (str): name of the template to find in <dir>/template/

    Returns:
        str: file's content as 1 string

    Raises:
        ValueError: no matching template file could be found
    """
    if ".sh" not in name:
        name += ".sh"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "template" / name
        if path.exists():
            paths.append(path)

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "template" for d in dirs]
        s = "Could not find template {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.sh"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant template files:\n{}".format(
                # FIX: str.join raises TypeError on Path objects — stringify
                # each element first
                "\n".join(map(str, paths))
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        return f.read()
635
+
636
+
637
def is_sampled(key, conf):
    """
    Is a key sampled or constant? Returns True when conf is empty.

    Args:
        key (str): key to check
        conf (dict): hyper parameter search configuration dict

    Returns:
        bool: key is sampled?
    """
    if not conf:
        return True
    value = conf.get(key)
    return isinstance(value, dict) and "sample" in value
651
+
652
+
653
if __name__ == "__main__":

    """
    Launch (or dry-run with dev=true) one `sbatch` job per hyper-parameter
    configuration, rendering a shell template with per-run values.

    Notes:
        * Must provide template name as template=name
        * `name`.sh should be in shared/template/
    """

    # -------------------------------
    # -----  Default Variables  -----
    # -------------------------------

    args = sys.argv[1:]
    command_output = ""
    user = os.environ.get("USER")
    home = os.environ.get("HOME")
    exp_conf = {}  # hyper-parameter search configuration (loaded from yaml)
    dev = False  # dev mode: print everything, submit nothing
    escape = False  # escape special characters in the sbatch parent dir
    verbose = False  # print each rendered sbatch file
    template_name = None
    hp_exp_name = None  # name of the yaml experiment configuration (exp=...)
    hp_search_nb = None  # number of sampled configurations (n_search=...)
    exp_path = None
    resume = None  # quoted checkpoint path to resume from, if any
    force_sbatchs = False  # write sbatch files to disk even in dev mode
    sbatch_base = Path(home) / "climategan_sbatchs"
    summary_dir = Path(home) / "climategan_exp_summaries"

    # keys that configure the search itself and must not leak into train args
    hp_search_private = {"n_search", "template", "search", "summary_dir"}

    # "hash" means: generate a unique, timestamped file name under sbatch_base
    sbatch_path = "hash"

    # --------------------------
    # -----  Sanity Check  -----
    # --------------------------

    for arg in args:
        if "=" not in arg or " = " in arg:
            raise ValueError(
                "Args should be passed as `key=value`. Received `{}`".format(arg)
            )

    # --------------------------------
    # -----  Parse Command Line  -----
    # --------------------------------

    # split on the FIRST "=" only so that values may themselves contain "="
    # (e.g. key=a=b keeps "a=b" as the value instead of truncating to "a")
    args_dict = {arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in args}

    assert "template" in args_dict, "Please specify template=xxx"
    template = read_template(args_dict["template"])
    # every {placeholder} in the template becomes a key to fill in
    template_dict = {k: None for k in get_template_params(template)}

    train_args = []
    for k, v in args_dict.items():

        if k == "verbose":
            if v != "0":
                verbose = True

        elif k == "sbatch_path":
            sbatch_path = v

        elif k == "sbatch_base":
            sbatch_base = Path(v).resolve()

        elif k == "force_sbatchs":
            force_sbatchs = v.lower() == "true"

        elif k == "dev":
            if v.lower() != "false":
                dev = True

        elif k == "escape":
            if v.lower() != "false":
                escape = True

        elif k == "template":
            template_name = v

        elif k == "exp":
            hp_exp_name = v

        elif k == "n_search":
            hp_search_nb = int(v)

        elif k == "resume":
            # quote the path so the shell template handles spaces correctly
            resume = f'"{v}"'
            template_dict[k] = f'"{v}"'

        elif k == "summary_dir":
            if v.lower() == "none":
                summary_dir = None
            else:
                summary_dir = Path(v)

        elif k in template_dict:
            template_dict[k] = v

        else:
            # anything unknown is forwarded verbatim to the training script
            train_args.append(f"{k}={v}")

    # ------------------------------------
    # -----  Load Experiment Config  -----
    # ------------------------------------

    if hp_exp_name is not None:
        exp_path, exp_conf = read_exp_conf(hp_exp_name)
        # the yaml file may set n_search; the command line takes precedence
        if "n_search" in exp_conf and hp_search_nb is None:
            hp_search_nb = exp_conf["n_search"]

        assert (
            hp_search_nb is not None
        ), "n_search should be specified in a yaml file or from the command line"

        hps = resolve(exp_conf, hp_search_nb)

    else:
        # no experiment file: a single run with command-line values only
        hps = [None]

    # ---------------------------------
    # -----  Run All Experiments  -----
    # ---------------------------------
    if summary_dir is not None:
        summary_dir.mkdir(exist_ok=True, parents=True)
    summary = None

    for hp_idx, hp in enumerate(hps):

        # copy shared values
        tmp_template_dict = template_dict.copy()
        tmp_train_args = train_args.copy()
        # split on the first "=" only (values may contain "=")
        tmp_train_args_dict = {
            arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in tmp_train_args
        }
        print_header(hp_idx)
        # override shared values with run-specific values for run hp_idx/n_search
        if hp is not None:
            for k, v in hp.items():
                if k == "resume" and resume is None:
                    resume = f'"{v}"'
                # hp-search params to ignore
                if k in hp_search_private:
                    continue

                if k == "codeloc":
                    v = escape_path(v)

                if k == "output":
                    Path(v).parent.mkdir(parents=True, exist_ok=True)

                # override template params depending on exp config
                if k in tmp_template_dict:
                    if template_dict[k] is None or is_sampled(k, exp_conf):
                        tmp_template_dict[k] = v
                # store sampled / specified params in current tmp_train_args_dict
                else:
                    if k in tmp_train_args_dict:
                        if is_sampled(k, exp_conf):
                            # warn if key was specified from the command line
                            tv = tmp_train_args_dict[k]
                            warn(
                                "\nWarning: overriding sampled config-file arg",
                                "{} to command-line value {}\n".format(k, tv),
                            )
                    else:
                        tmp_train_args_dict[k] = v

        # create sbatch file where required
        tmp_sbatch_path = None
        if sbatch_path == "hash":
            # unique per-run file name: <exp-prefix_>timestamp.sh
            tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
            tmp_sbatch_name += now() + ".sh"
            tmp_sbatch_path = sbatch_base / tmp_sbatch_name
            tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
            tmp_train_args_dict["exp_file"] = str(exp_path)
        else:
            tmp_sbatch_path = Path(sbatch_path).resolve()

        summary = extend_summary(
            summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
        )

        # format train.py's args and crop floats' precision to 5 digits
        tmp_template_dict["train_args"] = " ".join(
            sorted(
                [
                    "{}={}".format(k, clean_arg(v))
                    for k, v in tmp_train_args_dict.items()
                ]
            )
        )

        if "resume.py" in template and resume is None:
            raise ValueError("No `resume` value but using a resume.py template")

        # format template with clean dict (replace None with "")
        sbatch = template.format(
            **{
                k: v if v is not None else ""
                for k, v in tmp_template_dict.items()
                if k in template_dict
            }
        )

        # --------------------------------------
        # -----  Execute `sbatch` Command  -----
        # --------------------------------------
        if not dev or force_sbatchs:
            if tmp_sbatch_path.exists():
                # fix: warn with the actual file being overwritten
                # (previously printed `sbatch_path`, i.e. the literal "hash")
                print(f"Warning: overwriting {tmp_sbatch_path}")

            # write sbatch file
            with open(tmp_sbatch_path, "w") as f:
                f.write(sbatch)

            if not dev:
                # escape special characters such as " " from sbatch_path's parent dir
                parent = str(tmp_sbatch_path.parent)
                if escape:
                    parent = escape_path(parent)

                # create command to execute in a subprocess
                command = "sbatch {}".format(tmp_sbatch_path.name)
                # execute sbatch command & store output
                command_output = subprocess.run(
                    command.split(), stdout=subprocess.PIPE, cwd=parent
                )
                command_output = "\n" + command_output.stdout.decode("utf-8") + "\n"

                print(f"Running from {parent}:")
                print(f"$ {command}")

        # ---------------------------------
        # -----  Summarize Execution  -----
        # ---------------------------------
        if verbose:
            print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
        if not dev:
            print_box(command_output.strip())
            jobID = parse_jobID(command_output.strip())
            summary["Slurm JOBID"].append(jobID)

        # markdown reference-style link, filled in by hand afterwards
        summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")

        print(
            "{}{}Summary{} {}:".format(
                C.UNDERLINE,
                C.OKGREEN,
                C.ENDC,
                f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
            )
        )
        print(
            " "
            + "\n ".join(
                "{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
            )
        )
        print_footer()

    print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n")

    table, sum_path = search_summary_table(summary, summary_dir if not dev else None)
    if table is not None:
        print(table)
        print(
            "Add `[i]: https://...` at the end of a markdown document",
            "to fill in the comet links.\n",
        )
        if summary_dir is None:
            print("Add summary_dir=path to store the printed markdown table ⇪")
        else:
            print("Saved table in", str(sum_path))

    if not dev:
        print(
            "Cancel entire experiment? \n$ scancel",
            " ".join(map(str, summary["Slurm JOBID"])),
        )