Shivam Singh committed
Commit 353f374
1 Parent(s): 825f85e

update changes

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,167 @@
+gradio_cached_examples/
+
+# https://github.com/github/gitignore/blob/main/Python.gitignore
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+requirements-dev.lock
.python-version ADDED
@@ -0,0 +1 @@
+3.12.2
LICENSE ADDED
@@ -0,0 +1,235 @@
+MIT License
+
+Copyright (c) 2024 Lagon Technologies
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+--------------------------------------------------------------------------------
+
+This software contains portions of third party software provided under other licenses:
+
+src/esrgan_model.py (x)
+===================
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+(x) Modified from https://github.com/philz1337x/clarity-upscaler
+    which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
+    which is a copy of https://github.com/victorca25/iNNfer
+    which is a copy of https://github.com/xinntao/ESRGAN
examples/clarity_bird.webp ADDED

Git LFS Details

  • SHA256: a1bf18d88b928ba178dda6c773e4e3327c08f17c743c8f124ab03f3ef7100a65
  • Pointer size: 130 Bytes
  • Size of remote file: 42.5 kB
examples/edgar-infocus-gJH8AqpiSEU-unsplash.jpg ADDED

Git LFS Details

  • SHA256: d458cb591d83eaed54f406f0ff625a640ed3c72c02fc4728f66ceb5cd354cc0b
  • Pointer size: 130 Bytes
  • Size of remote file: 49.4 kB
examples/jeremy-wallace-_XjW3oN8UOE-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 04f17db7d49fd915237c7721f2723b43b4a0c53acfd737c3742b864993aac71e
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
examples/kara-eads-L7EwHkq1B2s-unsplash.jpg ADDED

Git LFS Details

  • SHA256: f3c6c772b5ef805f9b317d3b9547940a990ca8c983689ef553dd76a8be49a398
  • Pointer size: 130 Bytes
  • Size of remote file: 66.5 kB
examples/karina-vorozheeva-rW-I87aPY5Y-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 5e16ee5b0aae24133be45e7685dbabebbfafc17dfe08a7ceba3fc5b0f42a66dc
  • Pointer size: 130 Bytes
  • Size of remote file: 90 kB
examples/karographix-photography-hIaOPjYCEj4-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 83860b586d31cf981df74b7579bbcef9eede87d294b47d3a88211bbc9af25501
  • Pointer size: 130 Bytes
  • Size of remote file: 66 kB
examples/melissa-walker-horn-gtDYwUIr9Vg-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 1d1a7d7f5c5fff1ed335e0e1998f86fe2e4c20565f5a95dac72dc279cb60dae4
  • Pointer size: 130 Bytes
  • Size of remote file: 81.1 kB
examples/ryoji-iwata-X53e51WfjlE-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 4ad5c2789d05d1b7e4727d820e5bc8c9707e76e2f10408966970752acf398e36
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
examples/tadeusz-lakota-jggQZkITXng-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 128df6ad73db9eaa6fa9befe9eb43c91d500da4cedbe6ec4188f3ba336beaf7c
  • Pointer size: 130 Bytes
  • Size of remote file: 64.7 kB
pyproject.toml ADDED
@@ -0,0 +1,54 @@
+[project]
+name = "enhancer"
+version = "0.1.0"
+description = "Finegrain Image Enhancer"
+authors = [
+    { name = "Laurent Fainsin", email = "laurent@lagon.tech" }
+]
+dependencies = [
+    "gradio>=4.42.0",
+    "pillow>=10.4.0",
+    "gradio-imageslider>=0.0.20",
+    "pillow-heif>=0.18.0",
+    "refiners @ git+https://github.com/finegrain-ai/refiners",
+    "spaces>=0.29.3",
+    "numpy<2.0.0",
+]
+readme = "README.md"
+requires-python = ">= 3.12, <3.13"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = []
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/enhancer"]
+
+[tool.ruff]
+src = ["src"]  # https://docs.astral.sh/ruff/settings/#src
+exclude = ["esrgan_model.py"]
+line-length = 120
+target-version = "py312"
+
+[tool.ruff.lint]
+select = [
+    "E",   # pycodestyle errors
+    "W",   # pycodestyle warnings
+    "F",   # pyflakes
+    "UP",  # pyupgrade
+    "A",   # flake8-builtins
+    "B",   # flake8-bugbear
+    "Q",   # flake8-quotes
+    "I",   # isort
+]
+
+[tool.pyright]
+include = ["src"]
+exclude = ["**/__pycache__"]
requirements.lock ADDED
@@ -0,0 +1,231 @@
+# generated by rye
+# use `rye lock` or `rye sync` to update this lockfile
+#
+# last locked with the following flags:
+#   pre: false
+#   features: []
+#   all-features: false
+#   with-sources: false
+#   generate-hashes: false
+#   universal: false
+
+-e file:.
+aiofiles==23.2.1
+    # via gradio
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.4.0
+    # via gradio
+    # via httpx
+    # via starlette
+certifi==2024.8.30
+    # via httpcore
+    # via httpx
+    # via requests
+charset-normalizer==3.3.2
+    # via requests
+click==8.1.7
+    # via typer
+    # via uvicorn
+contourpy==1.3.0
+    # via matplotlib
+cycler==0.12.1
+    # via matplotlib
+fastapi==0.113.0
+    # via gradio
+ffmpy==0.4.0
+    # via gradio
+filelock==3.15.4
+    # via huggingface-hub
+    # via torch
+    # via triton
+fonttools==4.53.1
+    # via matplotlib
+fsspec==2024.9.0
+    # via gradio-client
+    # via huggingface-hub
+    # via torch
+gradio==4.42.0
+    # via enhancer
+    # via gradio-imageslider
+    # via spaces
+gradio-client==1.3.0
+    # via gradio
+gradio-imageslider==0.0.20
+    # via enhancer
+h11==0.14.0
+    # via httpcore
+    # via uvicorn
+httpcore==1.0.5
+    # via httpx
+httpx==0.27.2
+    # via gradio
+    # via gradio-client
+    # via spaces
+huggingface-hub==0.24.6
+    # via gradio
+    # via gradio-client
+idna==3.8
+    # via anyio
+    # via httpx
+    # via requests
+importlib-resources==6.4.4
+    # via gradio
+jaxtyping==0.2.34
+    # via refiners
+jinja2==3.1.4
+    # via gradio
+    # via torch
+kiwisolver==1.4.7
+    # via matplotlib
+markdown-it-py==3.0.0
+    # via rich
+markupsafe==2.1.5
+    # via gradio
+    # via jinja2
+matplotlib==3.9.2
+    # via gradio
+mdurl==0.1.2
+    # via markdown-it-py
+mpmath==1.3.0
+    # via sympy
+networkx==3.3
+    # via torch
+numpy==1.26.4
+    # via contourpy
+    # via enhancer
+    # via gradio
+    # via matplotlib
+    # via pandas
+    # via refiners
+nvidia-cublas-cu12==12.1.3.1
+    # via nvidia-cudnn-cu12
+    # via nvidia-cusolver-cu12
+    # via torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via nvidia-cusolver-cu12
+    # via torch
+nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.6.68
+    # via nvidia-cusolver-cu12
+    # via nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+orjson==3.10.7
+    # via gradio
+packaging==24.1
+    # via gradio
+    # via gradio-client
+    # via huggingface-hub
+    # via matplotlib
+    # via refiners
+    # via spaces
+pandas==2.2.2
+    # via gradio
+pillow==10.4.0
+    # via enhancer
+    # via gradio
+    # via gradio-imageslider
+    # via matplotlib
+    # via pillow-heif
+    # via refiners
+pillow-heif==0.18.0
+    # via enhancer
+psutil==5.9.8
+    # via spaces
+pydantic==2.8.2
+    # via fastapi
+    # via gradio
+    # via spaces
+pydantic-core==2.20.1
+    # via pydantic
+pydub==0.25.1
+    # via gradio
+pygments==2.18.0
+    # via rich
+pyparsing==3.1.4
+    # via matplotlib
+python-dateutil==2.9.0.post0
+    # via matplotlib
+    # via pandas
+python-multipart==0.0.9
+    # via gradio
+pytz==2024.1
+    # via pandas
+pyyaml==6.0.2
+    # via gradio
+    # via huggingface-hub
+refiners @ git+https://github.com/finegrain-ai/refiners@cf247a1b20609479565618f49bf70c8aa65a7cfd
+    # via enhancer
+requests==2.32.3
+    # via huggingface-hub
+    # via spaces
+rich==13.8.0
+    # via typer
+ruff==0.6.4
+    # via gradio
+safetensors==0.4.5
+    # via refiners
+semantic-version==2.10.0
+    # via gradio
+setuptools==74.1.2
+    # via torch
+shellingham==1.5.4
+    # via typer
+six==1.16.0
+    # via python-dateutil
+sniffio==1.3.1
+    # via anyio
+    # via httpx
+spaces==0.30.1
+    # via enhancer
+starlette==0.38.4
+    # via fastapi
+sympy==1.13.2
+    # via torch
+tomlkit==0.12.0
+    # via gradio
+torch==2.4.1
+    # via refiners
+tqdm==4.66.5
+    # via huggingface-hub
+triton==3.0.0
+    # via torch
+typeguard==2.13.3
+    # via jaxtyping
+typer==0.12.5
+    # via gradio
+typing-extensions==4.12.2
+    # via fastapi
+    # via gradio
+    # via gradio-client
+    # via huggingface-hub
+    # via pydantic
+    # via pydantic-core
+    # via spaces
+    # via torch
+    # via typer
+tzdata==2024.1
+    # via pandas
+urllib3==2.2.2
+    # via gradio
+    # via requests
+uvicorn==0.30.6
+    # via gradio
+websockets==12.0
+    # via gradio-client
requirements.txt ADDED
@@ -0,0 +1,6 @@
+git+https://github.com/finegrain-ai/refiners@a5d3c2971b84f6faa4762b1cf5a07f4f812bb1f5
+gradio_imageslider==0.0.20
+spaces==0.28.3
+numpy<2.0.0
+pillow>=10.4.0
+pillow-heif>=0.18.0
src/app.py ADDED
@@ -0,0 +1,299 @@
+from pathlib import Path
+
+import gradio as gr
+import pillow_heif
+import spaces
+import torch
+from gradio_imageslider import ImageSlider
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from refiners.fluxion.utils import manual_seed
+from refiners.foundationals.latent_diffusion import Solver, solvers
+
+from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
+
+pillow_heif.register_heif_opener()
+pillow_heif.register_avif_opener()
+
+TITLE = """
+<center>
+
+<h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
+Image Enhancer Powered By Refiners
+</h1>
+
+<div style="
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    gap: 0.5rem;
+    margin-bottom: 0.5rem;
+    font-size: 1.25rem;
+    flex-wrap: wrap;
+">
+<a href="https://blog.finegrain.ai/posts/reproducing-clarity-upscaler/" target="_blank">[Blog Post]</a>
+<a href="https://github.com/finegrain-ai/refiners" target="_blank">[Refiners]</a>
+<a href="https://finegrain.ai/" target="_blank">[Finegrain]</a>
+<a href="https://huggingface.co/spaces/finegrain/finegrain-object-eraser" target="_blank">
+[Finegrain Object Eraser]
+</a>
+<a href="https://huggingface.co/spaces/finegrain/finegrain-object-cutter" target="_blank">
+[Finegrain Object Cutter]
+</a>
+</div>
+
+<p>
+Turn low resolution images into high resolution versions with added generated details (your image will be modified).
+</p>
+
+<p>
+This space is powered by Refiners, our open source micro-framework for simple foundation model adaptation.
+If you enjoyed it, please consider starring Refiners on GitHub!
+</p>
+
+<a href="https://github.com/finegrain-ai/refiners" target="_blank">
+<img src="https://img.shields.io/github/stars/finegrain-ai/refiners?style=social" />
+</a>
+
+</center>
+"""
+
+CHECKPOINTS = ESRGANUpscalerCheckpoints(
+    unet=Path(
+        hf_hub_download(
+            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
+            filename="model.safetensors",
+            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
+        )
+    ),
+    clip_text_encoder=Path(
+        hf_hub_download(
+            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
+            filename="model.safetensors",
+            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
+        )
+    ),
+    lda=Path(
+        hf_hub_download(
+            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
+            filename="model.safetensors",
+            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
+        )
+    ),
+    controlnet_tile=Path(
+        hf_hub_download(
+            repo_id="refiners/controlnet.sd1_5.tile",
+            filename="model.safetensors",
+            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
+        )
+    ),
+    esrgan=Path(
+        hf_hub_download(
+            repo_id="philz1337x/upscaler",
+            filename="4x-UltraSharp.pth",
+            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
+        )
+    ),
+    negative_embedding=Path(
+        hf_hub_download(
+            repo_id="philz1337x/embeddings",
+            filename="JuggernautNegative-neg.pt",
+            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
+        )
+    ),
+    negative_embedding_key="string_to_param.*",
+    loras={
+        "more_details": Path(
+            hf_hub_download(
+                repo_id="philz1337x/loras",
+                filename="more_details.safetensors",
+                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
+            )
+        ),
+        "sdxl_render": Path(
+            hf_hub_download(
+                repo_id="philz1337x/loras",
+                filename="SDXLrender_v2.0.safetensors",
+                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
+            )
+        ),
+    },
+)
+
+# initialize the enhancer, on the cpu
+DEVICE_CPU = torch.device("cpu")
+DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)
+
+# "move" the enhancer to the gpu, this is handled by Zero GPU
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+enhancer.to(device=DEVICE, dtype=DTYPE)
+
+
+@spaces.GPU
+def process(
+    input_image: Image.Image,
+    prompt: str = "masterpiece, best quality, highres",
+    negative_prompt: str = "worst quality, low quality, normal quality",
+    seed: int = 42,
+    upscale_factor: int = 2,
+    controlnet_scale: float = 0.6,
+    controlnet_decay: float = 1.0,
+    condition_scale: int = 6,
+    tile_width: int = 112,
+    tile_height: int = 144,
+    denoise_strength: float = 0.35,
+    num_inference_steps: int = 18,
+    solver: str = "DDIM",
+) -> tuple[Image.Image, Image.Image]:
+    manual_seed(seed)
+
+    solver_type: type[Solver] = getattr(solvers, solver)
+
+    enhanced_image = enhancer.upscale(
+        image=input_image,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        upscale_factor=upscale_factor,
+        controlnet_scale=controlnet_scale,
+        controlnet_scale_decay=controlnet_decay,
+        condition_scale=condition_scale,
+        tile_size=(tile_height, tile_width),
+        denoise_strength=denoise_strength,
+        num_inference_steps=num_inference_steps,
+        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
+        solver_type=solver_type,
+    )
+
+    return (input_image, enhanced_image)
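For reference, the same entry point can be exercised outside the Space; a minimal sketch, assuming the checkpoints above have finished downloading and that the @spaces.GPU decorator degrades to a no-op off ZeroGPU (the example path is one of the files added in this commit):

    from PIL import Image
    before, after = process(Image.open("examples/clarity_bird.webp").convert("RGB"), upscale_factor=2)
    after.save("clarity_bird_2x.png")  # enhanced output at roughly 2x the input resolution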
+
+
+with gr.Blocks() as demo:
+    gr.HTML(TITLE)
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Input Image")
+            run_button = gr.ClearButton(components=None, value="Enhance Image")
+        with gr.Column():
+            output_slider = ImageSlider(label="Before / After")
+            run_button.add(output_slider)
+
+    with gr.Accordion("Advanced Options", open=False):
+        prompt = gr.Textbox(
+            label="Prompt",
+            placeholder="masterpiece, best quality, highres",
+        )
+        negative_prompt = gr.Textbox(
+            label="Negative Prompt",
+            placeholder="worst quality, low quality, normal quality",
+        )
+        seed = gr.Slider(
+            minimum=0,
+            maximum=10_000,
+            value=42,
+            step=1,
+            label="Seed",
+        )
+        upscale_factor = gr.Slider(
+            minimum=1,
+            maximum=4,
+            value=2,
+            step=0.2,
+            label="Upscale Factor",
+        )
+        controlnet_scale = gr.Slider(
+            minimum=0,
+            maximum=1.5,
+            value=0.6,
+            step=0.1,
+            label="ControlNet Scale",
+        )
+        controlnet_decay = gr.Slider(
+            minimum=0.5,
+            maximum=1,
+            value=1.0,
+            step=0.025,
+            label="ControlNet Scale Decay",
+        )
+        condition_scale = gr.Slider(
+            minimum=2,
+            maximum=20,
+            value=6,
+            step=1,
+            label="Condition Scale",
+        )
+        tile_width = gr.Slider(
+            minimum=64,
+            maximum=200,
+            value=112,
+            step=1,
+            label="Latent Tile Width",
+        )
+        tile_height = gr.Slider(
+            minimum=64,
+            maximum=200,
+            value=144,
+            step=1,
+            label="Latent Tile Height",
+        )
+        denoise_strength = gr.Slider(
+            minimum=0,
+            maximum=1,
+            value=0.35,
+            step=0.1,
+            label="Denoise Strength",
+        )
+        num_inference_steps = gr.Slider(
+            minimum=1,
+            maximum=30,
+            value=18,
+            step=1,
+            label="Number of Inference Steps",
+        )
+        solver = gr.Radio(
+            choices=["DDIM", "DPMSolver"],
+            value="DDIM",
+            label="Solver",
+        )
+
+    run_button.click(
+        fn=process,
+        inputs=[
+            input_image,
+            prompt,
+            negative_prompt,
+            seed,
+            upscale_factor,
+            controlnet_scale,
+            controlnet_decay,
+            condition_scale,
+            tile_width,
+            tile_height,
+            denoise_strength,
+            num_inference_steps,
+            solver,
+        ],
+        outputs=output_slider,
+    )
+
+    gr.Examples(
+        examples=[
+            "examples/kara-eads-L7EwHkq1B2s-unsplash.jpg",
+            "examples/clarity_bird.webp",
+            "examples/edgar-infocus-gJH8AqpiSEU-unsplash.jpg",
+            "examples/jeremy-wallace-_XjW3oN8UOE-unsplash.jpg",
+            "examples/karina-vorozheeva-rW-I87aPY5Y-unsplash.jpg",
+            "examples/karographix-photography-hIaOPjYCEj4-unsplash.jpg",
+            "examples/melissa-walker-horn-gtDYwUIr9Vg-unsplash.jpg",
+            "examples/ryoji-iwata-X53e51WfjlE-unsplash.jpg",
+            "examples/tadeusz-lakota-jggQZkITXng-unsplash.jpg",
+        ],
+        inputs=[input_image],
+        outputs=output_slider,
+        fn=process,
+        cache_examples="lazy",
+        run_on_click=False,
+    )
+
+demo.launch(share=False)
src/enhancer.py ADDED
@@ -0,0 +1,39 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import torch
+from PIL import Image
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
+    MultiUpscaler,
+    UpscalerCheckpoints,
+)
+
+from esrgan_model import UpscalerESRGAN
+
+
+@dataclass(kw_only=True)
+class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
+    esrgan: Path
+
+
+class ESRGANUpscaler(MultiUpscaler):
+    def __init__(
+        self,
+        checkpoints: ESRGANUpscalerCheckpoints,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
+        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
+        self.esrgan.to(device=device, dtype=dtype)
+
+    def to(self, device: torch.device, dtype: torch.dtype):
+        self.esrgan.to(device=device, dtype=dtype)
+        self.sd = self.sd.to(device=device, dtype=dtype)
+        self.device = device
+        self.dtype = dtype
+
+    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
+        image = self.esrgan.upscale_with_tiling(image)
+        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
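The division by 4 above reflects that the ESRGAN prefilter (4x-UltraSharp) is a fixed 4x model, so the diffusion stage only needs to bridge the remaining ratio. A quick sketch of the arithmetic, with illustrative values not taken from the commit:

    # for a 512 px wide input and a requested upscale_factor of 2:
    width = 512
    esrgan_width = width * 4              # 2048, after the fixed 4x ESRGAN pass
    final_width = esrgan_width * (2 / 4)  # 1024, after pre_upscale's upscale_factor / 4
    assert final_width == width * 2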
src/esrgan_model.py ADDED
@@ -0,0 +1,1068 @@
+# type: ignore
+"""
+Modified from https://github.com/philz1337x/clarity-upscaler
+which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
+which is a copy of https://github.com/victorca25/iNNfer
+which is a copy of https://github.com/xinntao/ESRGAN
+"""
+
+import math
+import os
+from collections import OrderedDict, namedtuple
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+
+####################
+# RRDBNet Generator
+####################
+
+
+class RRDBNet(nn.Module):
+    def __init__(
+        self,
+        in_nc,
+        out_nc,
+        nf,
+        nb,
+        nr=3,
+        gc=32,
+        upscale=4,
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        upsample_mode="upconv",
+        convtype="Conv2D",
+        finalact=None,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(RRDBNet, self).__init__()
+        n_upscale = int(math.log(upscale, 2))
+        if upscale == 3:
+            n_upscale = 1
+
+        self.resrgan_scale = 0
+        if in_nc % 16 == 0:
+            self.resrgan_scale = 1
+        elif in_nc != 4 and in_nc % 4 == 0:
+            self.resrgan_scale = 2
+
+        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+        rb_blocks = [
+            RRDB(
+                nf,
+                nr,
+                kernel_size=3,
+                gc=32,
+                stride=1,
+                bias=1,
+                pad_type="zero",
+                norm_type=norm_type,
+                act_type=act_type,
+                mode="CNA",
+                convtype=convtype,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            for _ in range(nb)
+        ]
+        LR_conv = conv_block(
+            nf,
+            nf,
+            kernel_size=3,
+            norm_type=norm_type,
+            act_type=None,
+            mode=mode,
+            convtype=convtype,
+        )
+
+        if upsample_mode == "upconv":
+            upsample_block = upconv_block
+        elif upsample_mode == "pixelshuffle":
+            upsample_block = pixelshuffle_block
+        else:
+            raise NotImplementedError(f"upsample mode [{upsample_mode}] is not found")
+        if upscale == 3:
+            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+        else:
+            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+        outact = act(finalact) if finalact else None
+
+        self.model = sequential(
+            fea_conv,
+            ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+            *upsampler,
+            HR_conv0,
+            HR_conv1,
+            outact,
+        )
+
+    def forward(self, x, outm=None):
+        if self.resrgan_scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        elif self.resrgan_scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        else:
+            feat = x
+
+        return self.model(feat)
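For orientation, a hedged sketch of how these branches play out: a plain 4x checkpoint such as 4x-UltraSharp typically maps to in_nc=3, so resrgan_scale stays 0 and the input passes through unchanged, while Real-ESRGAN 1x/2x variants widen in_nc so the pixel_unshuffle branches can fold the scale into channels. The parameter values below follow the common ESRGAN conventions rather than anything pinned by this commit:

    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)  # classic 4x ESRGAN layout (assumed)
    x = torch.randn(1, 3, 64, 64)
    assert net(x).shape == (1, 3, 256, 256)  # default upscale=4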
+
+
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(
+        self,
+        nf,
+        nr=3,
+        kernel_size=3,
+        gc=32,
+        stride=1,
+        bias=1,
+        pad_type="zero",
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        convtype="Conv2D",
+        spectral_norm=False,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(RRDB, self).__init__()
+        # This is for backwards compatibility with existing models
+        if nr == 3:
+            self.RDB1 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            self.RDB2 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+            self.RDB3 = ResidualDenseBlock_5C(
+                nf,
+                kernel_size,
+                gc,
+                stride,
+                bias,
+                pad_type,
+                norm_type,
+                act_type,
+                mode,
+                convtype,
+                spectral_norm=spectral_norm,
+                gaussian_noise=gaussian_noise,
+                plus=plus,
+            )
+        else:
+            RDB_list = [
+                ResidualDenseBlock_5C(
+                    nf,
+                    kernel_size,
+                    gc,
+                    stride,
+                    bias,
+                    pad_type,
+                    norm_type,
+                    act_type,
+                    mode,
+                    convtype,
+                    spectral_norm=spectral_norm,
+                    gaussian_noise=gaussian_noise,
+                    plus=plus,
+                )
+                for _ in range(nr)
+            ]
+            self.RDBs = nn.Sequential(*RDB_list)
+
+    def forward(self, x):
+        if hasattr(self, "RDB1"):
+            out = self.RDB1(x)
+            out = self.RDB2(out)
+            out = self.RDB3(out)
+        else:
+            out = self.RDBs(x)
+        return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    Modified options that can be used:
+    - "Partial Convolution based Padding" arXiv:1811.11718
+    - "Spectral normalization" arXiv:1802.05957
+    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+      {Rakotonirina} and A. {Rasoanaivo}
+    """
+
+    def __init__(
+        self,
+        nf=64,
+        kernel_size=3,
+        gc=32,
+        stride=1,
+        bias=1,
+        pad_type="zero",
+        norm_type=None,
+        act_type="leakyrelu",
+        mode="CNA",
+        convtype="Conv2D",
+        spectral_norm=False,
+        gaussian_noise=False,
+        plus=False,
+    ):
+        super(ResidualDenseBlock_5C, self).__init__()
+
+        self.noise = GaussianNoise() if gaussian_noise else None
+        self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+        self.conv1 = conv_block(
+            nf,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv2 = conv_block(
+            nf + gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv3 = conv_block(
+            nf + 2 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        self.conv4 = conv_block(
+            nf + 3 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+        if mode == "CNA":
+            last_act = None
+        else:
+            last_act = act_type
+        self.conv5 = conv_block(
+            nf + 4 * gc,
+            nf,
+            3,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=last_act,
+            mode=mode,
+            convtype=convtype,
+            spectral_norm=spectral_norm,
+        )
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        if self.conv1x1:
+            x2 = x2 + self.conv1x1(x)
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        if self.conv1x1:
+            x4 = x4 + x2
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        if self.noise:
+            return self.noise(x5.mul(0.2) + x)
+        else:
+            return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0.1, is_relative_detach=False):
+        super().__init__()
+        self.sigma = sigma
+        self.is_relative_detach = is_relative_detach
+        self.noise = torch.tensor(0, dtype=torch.float)
+
+    def forward(self, x):
+        if self.training and self.sigma != 0:
+            self.noise = self.noise.to(device=x.device, dtype=x.dtype)
+            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+            x = x + sampled_noise
+        return x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+    This class is copied from https://github.com/xinntao/Real-ESRGAN
+    """
+
+    def __init__(
+        self,
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_conv=16,
+        upscale=4,
+        act_type="prelu",
+    ):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == "relu":
+            activation = nn.ReLU(inplace=True)
+        elif act_type == "prelu":
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == "leakyrelu":
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == "relu":
+                activation = nn.ReLU(inplace=True)
+            elif act_type == "prelu":
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == "leakyrelu":
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
+        out += base
+        return out
+
+
+####################
+# Upsampler
+####################
+
+
+class Upsample(nn.Module):
+    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+    The input data is assumed to be of the form
+    `minibatch x channels x [optional depth] x [optional height] x width`.
+    """
+
+    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+        super(Upsample, self).__init__()
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.size = size
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        return nn.functional.interpolate(
+            x,
+            size=self.size,
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            align_corners=self.align_corners,
+        )
+
+    def extra_repr(self):
+        if self.scale_factor is not None:
+            info = f"scale_factor={self.scale_factor}"
+        else:
+            info = f"size={self.size}"
+        info += f", mode={self.mode}"
+        return info
+
+
+def pixel_unshuffle(x, scale):
+    """Pixel unshuffle.
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
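As a quick sanity check on the shape contract (illustrative only, not part of the commit):

    x = torch.randn(1, 3, 8, 8)
    y = pixel_unshuffle(x, scale=2)
    assert y.shape == (1, 12, 4, 4)  # c * scale**2 channels, h // scale, w // scale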
+
+
+def pixelshuffle_block(
+    in_nc,
+    out_nc,
+    upscale_factor=2,
+    kernel_size=3,
+    stride=1,
+    bias=True,
+    pad_type="zero",
+    norm_type=None,
+    act_type="relu",
+    convtype="Conv2D",
+):
+    """
+    Pixel shuffle layer
+    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+    Neural Network, CVPR17)
+    """
+    conv = conv_block(
+        in_nc,
+        out_nc * (upscale_factor**2),
+        kernel_size,
+        stride,
+        bias=bias,
+        pad_type=pad_type,
+        norm_type=None,
+        act_type=None,
+        convtype=convtype,
+    )
+    pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+    n = norm(norm_type, out_nc) if norm_type else None
+    a = act(act_type) if act_type else None
+    return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(
+    in_nc,
+    out_nc,
+    upscale_factor=2,
+    kernel_size=3,
+    stride=1,
+    bias=True,
+    pad_type="zero",
+    norm_type=None,
+    act_type="relu",
+    mode="nearest",
+    convtype="Conv2D",
+):
+    """Upconv layer"""
+    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == "Conv3D" else upscale_factor
+    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+    conv = conv_block(
+        in_nc,
+        out_nc,
+        kernel_size,
+        stride,
+        bias=bias,
+        pad_type=pad_type,
+        norm_type=norm_type,
+        act_type=act_type,
+        convtype=convtype,
+    )
+    return sequential(upsample, conv)
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+    Args:
+        basic_block (nn.module): nn.module class for basic block. (block)
+        num_basic_block (int): number of blocks. (n_layers)
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+    """activation helper"""
+    act_type = act_type.lower()
+    if act_type == "relu":
+        layer = nn.ReLU(inplace)
+    elif act_type in ("leakyrelu", "lrelu"):
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == "prelu":
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    elif act_type == "tanh":  # [-1, 1] range output
+        layer = nn.Tanh()
+    elif act_type == "sigmoid":  # [0, 1] range output
+        layer = nn.Sigmoid()
+    else:
+        raise NotImplementedError(f"activation layer [{act_type}] is not found")
+    return layer
+
+
+class Identity(nn.Module):
+    def __init__(self, *kwargs):
+        super(Identity, self).__init__()
+
+    def forward(self, x, *kwargs):
+        return x
+
+
+def norm(norm_type, nc):
+    """Return a normalization layer"""
+    norm_type = norm_type.lower()
+    if norm_type == "batch":
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == "instance":
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    elif norm_type == "none":
+        layer = Identity()  # no-op normalization
609
+ else:
610
+ raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
611
+ return layer
612
+
613
+
614
+ def pad(pad_type, padding):
+     """padding layer helper"""
+     pad_type = pad_type.lower()
+     if padding == 0:
+         return None
+     if pad_type == "reflect":
+         layer = nn.ReflectionPad2d(padding)
+     elif pad_type == "replicate":
+         layer = nn.ReplicationPad2d(padding)
+     elif pad_type == "zero":
+         layer = nn.ZeroPad2d(padding)
+     else:
+         raise NotImplementedError(f"padding layer [{pad_type}] is not implemented")
+     return layer
+
+
+ def get_valid_padding(kernel_size, dilation):
+     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+     padding = (kernel_size - 1) // 2
+     return padding
+
+
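+ # Worked example: with kernel_size=3 and dilation=2 the effective receptive
+ # field is 3 + (3 - 1) * (2 - 1) = 5, so "same" padding is (5 - 1) // 2 = 2.
+ assert get_valid_padding(3, 1) == 1
+ assert get_valid_padding(3, 2) == 2
+
+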
+ class ShortcutBlock(nn.Module):
+     """Elementwise sum the output of a submodule to its input"""
+
+     def __init__(self, submodule):
+         super(ShortcutBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = x + self.sub(x)
+         return output
+
+     def __repr__(self):
+         return "Identity + \n|" + self.sub.__repr__().replace("\n", "\n|")
+
+
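+ # Sketch: ShortcutBlock(f) computes x + f(x), i.e. a residual connection, so
+ # the wrapped submodule must preserve the input's shape.
+ def _shortcut_demo():
+     block = ShortcutBlock(nn.Conv2d(8, 8, kernel_size=3, padding=1))
+     x = torch.randn(1, 8, 4, 4)
+     assert block(x).shape == x.shape
+
+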
+ def sequential(*args):
+     """Flatten Sequential. It unwraps nn.Sequential."""
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError("sequential does not support OrderedDict input.")
+         return args[0]  # No sequential is needed.
+     modules = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
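+ # Sketch: besides unwrapping nested nn.Sequential containers, sequential()
+ # silently drops non-module entries such as None, which is what lets callers
+ # pass optional norm/act layers straight through.
+ def _sequential_demo():
+     inner = nn.Sequential(nn.ReLU(), nn.ReLU())
+     flat = sequential(inner, None, nn.Tanh())
+     assert len(flat) == 3  # two unwrapped children + Tanh; None was skipped
+
+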
+ def conv_block(
+     in_nc,
+     out_nc,
+     kernel_size,
+     stride=1,
+     dilation=1,
+     groups=1,
+     bias=True,
+     pad_type="zero",
+     norm_type=None,
+     act_type="relu",
+     mode="CNA",
+     convtype="Conv2D",
+     spectral_norm=False,
+ ):
+     """Conv layer with padding, normalization, activation"""
+     assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
+     padding = get_valid_padding(kernel_size, dilation)
+     p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
+     padding = padding if pad_type == "zero" else 0
+
+     if convtype == "PartialConv2D":
+         # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
+         from torchvision.ops import PartialConv2d
+
+         c = PartialConv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     elif convtype == "DeformConv2D":
+         from torchvision.ops import DeformConv2d  # not tested
+
+         c = DeformConv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     elif convtype == "Conv3D":
+         c = nn.Conv3d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     else:
+         c = nn.Conv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+
+     if spectral_norm:
+         c = nn.utils.spectral_norm(c)
+
+     a = act(act_type) if act_type else None
+     if "CNA" in mode:
+         n = norm(norm_type, out_nc) if norm_type else None
+         return sequential(p, c, n, a)
+     elif mode == "NAC":
+         if norm_type is None and act_type is not None:
+             a = act(act_type, inplace=False)
+         n = norm(norm_type, in_nc) if norm_type else None
+         return sequential(n, a, p, c)
+
+
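+ # Sketch of the assembly order: in the default "CNA" mode the block is
+ # (Pad ->) Conv -> Norm -> Act, with the explicit pad layer present only for
+ # pad types other than the default zero padding.
+ def _conv_block_demo():
+     blk = conv_block(3, 16, kernel_size=3, pad_type="reflect", norm_type="batch", act_type="relu")
+     # ReflectionPad2d -> Conv2d -> BatchNorm2d -> ReLU
+     assert len(blk) == 4
+
+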
+ def load_models(
+     model_path: Path,
+     command_path: str = None,
+ ) -> list:
+     """
+     A one-and-done loader that collects candidate model files from the
+     specified directories.
+
+     @param model_path: The location to store/find models in.
+     @param command_path: A command-line argument to search for models in first.
+     @return: A list of paths containing the desired model(s)
+     """
+     output = []
+
+     try:
+         places = []
+         if command_path is not None and command_path != model_path:
+             pretrained_path = os.path.join(command_path, "experiments/pretrained_models")
+             if os.path.exists(pretrained_path):
+                 print(f"Appending path: {pretrained_path}")
+                 places.append(pretrained_path)
+             elif os.path.exists(command_path):
+                 places.append(command_path)
+
+         places.append(model_path)
+
+         for place in places:
+             if os.path.isdir(place):
+                 for file in sorted(os.listdir(place)):
+                     full_path = os.path.join(place, file)
+                     if os.path.isfile(full_path):
+                         output.append(full_path)
+     except Exception:
+         pass
+
+     return output
+
+
+ def mod2normal(state_dict):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     if "conv_first.weight" in state_dict:
+         crt_net = {}
+         items = list(state_dict)
+
+         crt_net["model.0.weight"] = state_dict["conv_first.weight"]
+         crt_net["model.0.bias"] = state_dict["conv_first.bias"]
+
+         for k in items.copy():
+             if "RDB" in k:
+                 ori_k = k.replace("RRDB_trunk.", "model.1.sub.")
+                 if ".weight" in k:
+                     ori_k = ori_k.replace(".weight", ".0.weight")
+                 elif ".bias" in k:
+                     ori_k = ori_k.replace(".bias", ".0.bias")
+                 crt_net[ori_k] = state_dict[k]
+                 items.remove(k)
+
+         crt_net["model.1.sub.23.weight"] = state_dict["trunk_conv.weight"]
+         crt_net["model.1.sub.23.bias"] = state_dict["trunk_conv.bias"]
+         crt_net["model.3.weight"] = state_dict["upconv1.weight"]
+         crt_net["model.3.bias"] = state_dict["upconv1.bias"]
+         crt_net["model.6.weight"] = state_dict["upconv2.weight"]
+         crt_net["model.6.bias"] = state_dict["upconv2.bias"]
+         crt_net["model.8.weight"] = state_dict["HRconv.weight"]
+         crt_net["model.8.bias"] = state_dict["HRconv.bias"]
+         crt_net["model.10.weight"] = state_dict["conv_last.weight"]
+         crt_net["model.10.bias"] = state_dict["conv_last.bias"]
+         state_dict = crt_net
+     return state_dict
+
+
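+ # Key-renaming sketch for the converter above, using a minimal fake state dict
+ # (None values stand in for tensors; key names follow the new-style arch):
+ def _mod2normal_demo():
+     heads = ["conv_first", "trunk_conv", "upconv1", "upconv2", "HRconv", "conv_last"]
+     sd = {f"{name}.{p}": None for name in heads for p in ("weight", "bias")}
+     sd["RRDB_trunk.0.RDB1.conv1.weight"] = None
+     out = mod2normal(sd)
+     assert "model.0.weight" in out
+     assert "model.1.sub.0.RDB1.conv1.0.weight" in out
+
+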
+ def resrgan2normal(state_dict, nb=23):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+         re8x = 0
+         crt_net = {}
+         items = list(state_dict)
+
+         crt_net["model.0.weight"] = state_dict["conv_first.weight"]
+         crt_net["model.0.bias"] = state_dict["conv_first.bias"]
+
+         for k in items.copy():
+             if "rdb" in k:
+                 ori_k = k.replace("body.", "model.1.sub.")
+                 ori_k = ori_k.replace(".rdb", ".RDB")
+                 if ".weight" in k:
+                     ori_k = ori_k.replace(".weight", ".0.weight")
+                 elif ".bias" in k:
+                     ori_k = ori_k.replace(".bias", ".0.bias")
+                 crt_net[ori_k] = state_dict[k]
+                 items.remove(k)
+
+         crt_net[f"model.1.sub.{nb}.weight"] = state_dict["conv_body.weight"]
+         crt_net[f"model.1.sub.{nb}.bias"] = state_dict["conv_body.bias"]
+         crt_net["model.3.weight"] = state_dict["conv_up1.weight"]
+         crt_net["model.3.bias"] = state_dict["conv_up1.bias"]
+         crt_net["model.6.weight"] = state_dict["conv_up2.weight"]
+         crt_net["model.6.bias"] = state_dict["conv_up2.bias"]
+
+         if "conv_up3.weight" in state_dict:
+             # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+             re8x = 3
+             crt_net["model.9.weight"] = state_dict["conv_up3.weight"]
+             crt_net["model.9.bias"] = state_dict["conv_up3.bias"]
+
+         crt_net[f"model.{8+re8x}.weight"] = state_dict["conv_hr.weight"]
+         crt_net[f"model.{8+re8x}.bias"] = state_dict["conv_hr.bias"]
+         crt_net[f"model.{10+re8x}.weight"] = state_dict["conv_last.weight"]
+         crt_net[f"model.{10+re8x}.bias"] = state_dict["conv_last.bias"]
+
+         state_dict = crt_net
+     return state_dict
+
+
+ def infer_params(state_dict):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     scale2x = 0
+     scalemin = 6
+     n_uplayer = 0
+     plus = False
+
+     for block in list(state_dict):
+         parts = block.split(".")
+         n_parts = len(parts)
+         if n_parts == 5 and parts[2] == "sub":
+             nb = int(parts[3])
+         elif n_parts == 3:
+             part_num = int(parts[1])
+             if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
+                 scale2x += 1
+             if part_num > n_uplayer:
+                 n_uplayer = part_num
+                 out_nc = state_dict[block].shape[0]
+         if not plus and "conv1x1" in block:
+             plus = True
+
+     nf = state_dict["model.0.weight"].shape[0]
+     in_nc = state_dict["model.0.weight"].shape[1]
+     scale = 2**scale2x
+
+     return in_nc, out_nc, nf, nb, plus, scale
+
+
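+ # Inference sketch: upsampling convs live at "model.N.weight" with N > 6 in the
+ # old-style layout, so counting them gives log2 of the scale. A minimal fake
+ # 4x state dict (zero tensors stand in for real weights):
+ def _infer_params_demo():
+     sd = {
+         "model.0.weight": torch.zeros(64, 3, 3, 3),
+         "model.1.sub.23.weight": torch.zeros(64, 64, 3, 3),
+         "model.3.weight": torch.zeros(64, 64, 3, 3),
+         "model.6.weight": torch.zeros(64, 64, 3, 3),
+         "model.8.weight": torch.zeros(64, 64, 3, 3),
+         "model.10.weight": torch.zeros(3, 64, 3, 3),
+     }
+     assert infer_params(sd) == (3, 3, 64, 23, False, 4)
+
+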
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+     w = image.width
+     h = image.height
+
+     non_overlap_width = tile_w - overlap
+     non_overlap_height = tile_h - overlap
+
+     cols = math.ceil((w - overlap) / non_overlap_width)
+     rows = math.ceil((h - overlap) / non_overlap_height)
+
+     dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+     dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
+
+     grid = Grid([], tile_w, tile_h, w, h, overlap)
+     for row in range(rows):
+         row_images = []
+
+         y = int(row * dy)
+
+         if y + tile_h >= h:
+             y = h - tile_h
+
+         for col in range(cols):
+             x = int(col * dx)
+
+             if x + tile_w >= w:
+                 x = w - tile_w
+
+             tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+             row_images.append([x, tile_w, tile])
+
+         grid.tiles.append([y, tile_h, row_images])
+
+     return grid
+
+
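+ # Worked example of the tiling arithmetic: a 1024x768 image with 512px tiles
+ # and 64px overlap yields ceil((1024 - 64) / 448) = 3 columns and
+ # ceil((768 - 64) / 448) = 2 rows.
+ def _split_grid_demo():
+     grid = split_grid(Image.new("RGB", (1024, 768)))
+     assert len(grid.tiles) == 2  # rows
+     assert len(grid.tiles[0][2]) == 3  # tiles per row
+
+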
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
+ def combine_grid(grid):
+     def make_mask_image(r):
+         r = r * 255 / grid.overlap
+         r = r.astype(np.uint8)
+         return Image.fromarray(r, "L")
+
+     mask_w = make_mask_image(
+         np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
+     )
+     mask_h = make_mask_image(
+         np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
+     )
+
+     combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+     for y, h, row in grid.tiles:
+         combined_row = Image.new("RGB", (grid.image_w, h))
+         for x, w, tile in row:
+             if x == 0:
+                 combined_row.paste(tile, (0, 0))
+                 continue
+
+             combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+             combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+         if y == 0:
+             combined_image.paste(combined_row, (0, 0))
+             continue
+
+         combined_image.paste(
+             combined_row.crop((0, 0, combined_row.width, grid.overlap)),
+             (0, y),
+             mask=mask_h,
+         )
+         combined_image.paste(
+             combined_row.crop((0, grid.overlap, combined_row.width, h)),
+             (0, y + grid.overlap),
+         )
+
+     return combined_image
+
+
+ class UpscalerESRGAN:
+     def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+         self.device = device
+         self.dtype = dtype
+         self.model_path = model_path
+         self.model = self.load_model(model_path)
+         self.to(device, dtype)
+
+     def __call__(self, img: Image.Image) -> Image.Image:
+         return self.upscale_without_tiling(img)
+
+     def to(self, device: torch.device, dtype: torch.dtype):
+         self.device = device
+         self.dtype = dtype
+         self.model.to(device=device, dtype=dtype)
+
+     def load_model(self, path: Path) -> SRVGGNetCompact | RRDBNet:
+         filename = path
+         state_dict = torch.load(filename, weights_only=True, map_location=self.device)
+
+         if "params_ema" in state_dict:
+             state_dict = state_dict["params_ema"]
+         elif "params" in state_dict:
+             state_dict = state_dict["params"]
+             num_conv = 16 if "realesr-animevideov3" in filename.name else 32
+             model = SRVGGNetCompact(
+                 num_in_ch=3,
+                 num_out_ch=3,
+                 num_feat=64,
+                 num_conv=num_conv,
+                 upscale=4,
+                 act_type="prelu",
+             )
+             model.load_state_dict(state_dict)
+             model.eval()
+             return model
+
+         if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+             nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename.name else 23
+             state_dict = resrgan2normal(state_dict, nb)
+         elif "conv_first.weight" in state_dict:
+             state_dict = mod2normal(state_dict)
+         elif "model.0.weight" not in state_dict:
+             raise Exception("The file is not a recognized ESRGAN model.")
+
+         in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+
+         model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+         model.load_state_dict(state_dict)
+         model.eval()
+
+         return model
+
+     def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
+         img = np.array(img)
+         img = img[:, :, ::-1]  # RGB -> BGR, the channel order these models expect
+         img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+         img = torch.from_numpy(img).float()
+         img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+         with torch.no_grad():
+             output = self.model(img)
+         output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+         output = 255.0 * np.moveaxis(output, 0, 2)
+         output = output.astype(np.uint8)
+         output = output[:, :, ::-1]  # BGR -> RGB
+         return Image.fromarray(output, "RGB")
+
+     # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
+     def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
+         grid = split_grid(img)
+         newtiles = []
+         scale_factor = 1
+
+         for y, h, row in grid.tiles:
+             newrow = []
+             for tiledata in row:
+                 x, w, tile = tiledata
+
+                 output = self.upscale_without_tiling(tile)
+                 scale_factor = output.width // tile.width
+
+                 newrow.append([x * scale_factor, w * scale_factor, output])
+             newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+         newgrid = Grid(
+             newtiles,
+             grid.tile_w * scale_factor,
+             grid.tile_h * scale_factor,
+             grid.image_w * scale_factor,
+             grid.image_h * scale_factor,
+             grid.overlap * scale_factor,
+         )
+         output = combine_grid(newgrid)
+         return output
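+
+
+ # Usage sketch (the checkpoint and image file names are placeholders, not
+ # shipped with this commit): load a .pth checkpoint, upscale tile-by-tile,
+ # and save the result.
+ def _upscaler_usage_example(model_file: Path) -> None:
+     upscaler = UpscalerESRGAN(model_file, torch.device("cpu"), torch.float32)
+     image = Image.open("input.png").convert("RGB")
+     upscaled = upscaler.upscale_with_tiling(image)  # split -> upscale -> re-blend
+     upscaled.save("output.png")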