ali-ghamdan commited on
Commit
26626a1
1 Parent(s): de27012
Files changed (46) hide show
  1. .gitignore +139 -0
  2. .pre-commit-config.yaml +46 -0
  3. CODE_OF_CONDUCT.md +128 -0
  4. Comparisons.md +24 -0
  5. FAQ.md +7 -0
  6. LICENSE +351 -0
  7. MANIFEST.in +8 -0
  8. PaperModel.md +76 -0
  9. README_CN.md +7 -0
  10. VERSION +1 -0
  11. app.py +52 -0
  12. gfpgan/__init__.py +7 -0
  13. gfpgan/archs/__init__.py +10 -0
  14. gfpgan/archs/arcface_arch.py +245 -0
  15. gfpgan/archs/gfpgan_bilinear_arch.py +312 -0
  16. gfpgan/archs/gfpganv1_arch.py +439 -0
  17. gfpgan/archs/gfpganv1_clean_arch.py +324 -0
  18. gfpgan/archs/stylegan2_bilinear_arch.py +613 -0
  19. gfpgan/archs/stylegan2_clean_arch.py +368 -0
  20. gfpgan/data/__init__.py +10 -0
  21. gfpgan/data/ffhq_degradation_dataset.py +230 -0
  22. gfpgan/models/__init__.py +10 -0
  23. gfpgan/models/gfpgan_model.py +579 -0
  24. gfpgan/train.py +11 -0
  25. gfpgan/utils.py +144 -0
  26. gfpgan/weights/README.md +3 -0
  27. inference_gfpgan.py +155 -0
  28. options/train_gfpgan_v1.yml +216 -0
  29. options/train_gfpgan_v1_simple.yml +182 -0
  30. requirements.txt +12 -0
  31. scripts/convert_gfpganv_to_clean.py +164 -0
  32. scripts/parse_landmark.py +85 -0
  33. setup.cfg +33 -0
  34. setup.py +107 -0
  35. tests/data/ffhq_gt.lmdb/data.mdb +0 -0
  36. tests/data/ffhq_gt.lmdb/lock.mdb +0 -0
  37. tests/data/ffhq_gt.lmdb/meta_info.txt +1 -0
  38. tests/data/test_eye_mouth_landmarks.pth +3 -0
  39. tests/data/test_ffhq_degradation_dataset.yml +24 -0
  40. tests/data/test_gfpgan_model.yml +140 -0
  41. tests/test_arcface_arch.py +49 -0
  42. tests/test_ffhq_degradation_dataset.py +96 -0
  43. tests/test_gfpgan_arch.py +203 -0
  44. tests/test_gfpgan_model.py +132 -0
  45. tests/test_stylegan2_clean_arch.py +52 -0
  46. tests/test_utils.py +43 -0
.gitignore ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignored folders
2
+ datasets/*
3
+ experiments/*
4
+ results/*
5
+ tb_logger/*
6
+ wandb/*
7
+ tmp/*
8
+
9
+ version.py
10
+
11
+ # Byte-compiled / optimized / DLL files
12
+ __pycache__/
13
+ *.py[cod]
14
+ *$py.class
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ pip-wheel-metadata/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+
64
+ # Translations
65
+ *.mo
66
+ *.pot
67
+
68
+ # Django stuff:
69
+ *.log
70
+ local_settings.py
71
+ db.sqlite3
72
+ db.sqlite3-journal
73
+
74
+ # Flask stuff:
75
+ instance/
76
+ .webassets-cache
77
+
78
+ # Scrapy stuff:
79
+ .scrapy
80
+
81
+ # Sphinx documentation
82
+ docs/_build/
83
+
84
+ # PyBuilder
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105
+ __pypackages__/
106
+
107
+ # Celery stuff
108
+ celerybeat-schedule
109
+ celerybeat.pid
110
+
111
+ # SageMath parsed files
112
+ *.sage.py
113
+
114
+ # Environments
115
+ .env
116
+ .venv
117
+ env/
118
+ venv/
119
+ ENV/
120
+ env.bak/
121
+ venv.bak/
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ # Pyre type checker
139
+ .pyre/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ # flake8
3
+ - repo: https://github.com/PyCQA/flake8
4
+ rev: 3.8.3
5
+ hooks:
6
+ - id: flake8
7
+ args: ["--config=setup.cfg", "--ignore=W504, W503"]
8
+
9
+ # modify known_third_party
10
+ - repo: https://github.com/asottile/seed-isort-config
11
+ rev: v2.2.0
12
+ hooks:
13
+ - id: seed-isort-config
14
+
15
+ # isort
16
+ - repo: https://github.com/timothycrosley/isort
17
+ rev: 5.2.2
18
+ hooks:
19
+ - id: isort
20
+
21
+ # yapf
22
+ - repo: https://github.com/pre-commit/mirrors-yapf
23
+ rev: v0.30.0
24
+ hooks:
25
+ - id: yapf
26
+
27
+ # codespell
28
+ - repo: https://github.com/codespell-project/codespell
29
+ rev: v2.1.0
30
+ hooks:
31
+ - id: codespell
32
+
33
+ # pre-commit-hooks
34
+ - repo: https://github.com/pre-commit/pre-commit-hooks
35
+ rev: v3.2.0
36
+ hooks:
37
+ - id: trailing-whitespace # Trim trailing whitespace
38
+ - id: check-yaml # Attempt to load all yaml files to verify syntax
39
+ - id: check-merge-conflict # Check for files that contain merge conflict strings
40
+ - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
41
+ - id: end-of-file-fixer # Make sure files end in a newline and only a newline
42
+ - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
43
+ - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
44
+ args: ["--remove"]
45
+ - id: mixed-line-ending # Replace or check mixed line ending
46
+ args: ["--fix=lf"]
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ xintao.wang@outlook.com or xintaowang@tencent.com.
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
Comparisons.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Comparisons
2
+
3
+ ## Comparisons among different model versions
4
+
5
+ Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs.
6
+
7
+ | Version | Strengths | Weaknesses |
8
+ | :---: | :---: | :---: |
9
+ |V1.3 | ✓ natural outputs<br> ✓better results on very low-quality inputs <br> ✓ work on relatively high-quality inputs <br>✓ can have repeated (twice) restorations | ✗ not very sharp <br> ✗ have a slight change on identity |
10
+ |V1.2 | ✓ sharper output <br> ✓ with beauty makeup | ✗ some outputs are unnatural|
11
+
12
+ For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size.
13
+
14
+ | Input | V1 | V1.2 | V1.3
15
+ | :---: | :---: | :---: | :---: |
16
+ |![019_Anne_Hathaway_01_00](https://user-images.githubusercontent.com/17445847/153762146-96b25999-4ddd-42a5-a3fe-bb90565f4c4f.png)| ![](https://user-images.githubusercontent.com/17445847/153762256-ef41e749-5a27-495c-8a9c-d8403be55869.png) | ![](https://user-images.githubusercontent.com/17445847/153762297-d41582fc-6253-4e7e-a1ce-4dc237ae3bf3.png) | ![](https://user-images.githubusercontent.com/17445847/153762215-e0535e94-b5ba-426e-97b5-35c00873604d.png) |
17
+ | ![106_Harry_Styles_00_00](https://user-images.githubusercontent.com/17445847/153789040-632c0eda-c15a-43e9-a63c-9ead64f92d4a.png) | ![](https://user-images.githubusercontent.com/17445847/153789172-93cd4980-5318-4633-a07e-1c8f8064ff89.png) | ![](https://user-images.githubusercontent.com/17445847/153789185-f7b268a7-d1db-47b0-ae4a-335e5d657a18.png) | ![](https://user-images.githubusercontent.com/17445847/153789198-7c7f3bca-0ef0-4494-92f0-20aa6f7d7464.png)|
18
+ | ![076_Paris_Hilton_00_00](https://user-images.githubusercontent.com/17445847/153789607-86387770-9db8-441f-b08a-c9679b121b85.png) | ![](https://user-images.githubusercontent.com/17445847/153789619-e56b438a-78a0-425d-8f44-ec4692a43dda.png) | ![](https://user-images.githubusercontent.com/17445847/153789633-5b28f778-3b7f-4e08-8a1d-740ca6e82d8a.png) | ![](https://user-images.githubusercontent.com/17445847/153789645-bc623f21-b32d-4fc3-bfe9-61203407a180.png)|
19
+ | ![008_George_Clooney_00_00](https://user-images.githubusercontent.com/17445847/153790017-0c3ca94d-1c9d-4a0e-b539-ab12d4da98ff.png) | ![](https://user-images.githubusercontent.com/17445847/153790028-fb0d38ab-399d-4a30-8154-2dcd72ca90e8.png) | ![](https://user-images.githubusercontent.com/17445847/153790044-1ef68e34-6120-4439-a5d9-0b6cdbe9c3d0.png) | ![](https://user-images.githubusercontent.com/17445847/153790059-a8d3cece-8989-4e9a-9ffe-903e1690cfd6.png)|
20
+ | ![057_Madonna_01_00](https://user-images.githubusercontent.com/17445847/153790624-2d0751d0-8fb4-4806-be9d-71b833c2c226.png) | ![](https://user-images.githubusercontent.com/17445847/153790639-7eb870e5-26b2-41dc-b139-b698bb40e6e6.png) | ![](https://user-images.githubusercontent.com/17445847/153790651-86899b7a-a1b6-4242-9e8a-77b462004998.png) | ![](https://user-images.githubusercontent.com/17445847/153790655-c8f6c25b-9b4e-4633-b16f-c43da86cff8f.png)|
21
+ | ![044_Amy_Schumer_01_00](https://user-images.githubusercontent.com/17445847/153790811-3fb4fc46-5b4f-45fe-8fcb-a128de2bfa60.png) | ![](https://user-images.githubusercontent.com/17445847/153790817-d45aa4ff-bfc4-4163-b462-75eef9426fab.png) | ![](https://user-images.githubusercontent.com/17445847/153790824-5f93c3a0-fe5a-42f6-8b4b-5a5de8cd0ac3.png) | ![](https://user-images.githubusercontent.com/17445847/153790835-0edf9944-05c7-41c4-8581-4dc5ffc56c9d.png)|
22
+ | ![012_Jackie_Chan_01_00](https://user-images.githubusercontent.com/17445847/153791176-737b016a-e94f-4898-8db7-43e7762141c9.png) | ![](https://user-images.githubusercontent.com/17445847/153791183-2f25a723-56bf-4cd5-aafe-a35513a6d1c5.png) | ![](https://user-images.githubusercontent.com/17445847/153791194-93416cf9-2b58-4e70-b806-27e14c58d4fd.png) | ![](https://user-images.githubusercontent.com/17445847/153791202-aa98659c-b702-4bce-9c47-a2fa5eccc5ae.png)|
23
+
24
+ <!-- | ![]() | ![]() | ![]() | ![]()| -->
FAQ.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # FAQ
2
+
3
+ 1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model**
4
+
5
+ **A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying.
6
+ 2) This model is not directly trained. Instead, it is converted from another *bilinear* model.
7
+ 3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion.
LICENSE ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tencent is pleased to support the open source community by making GFPGAN available.
2
+
3
+ Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ ---------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
36
+
37
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
73
+ ---------------------------------------------
74
+ 1. basicsr
75
+ Copyright 2018-2020 BasicSR Authors
76
+
77
+
78
+ This BasicSR project is released under the Apache 2.0 license.
79
+
80
+ A copy of Apache 2.0 is included in this file.
81
+
82
+ StyleGAN2
83
+ The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
84
+ The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
85
+ DFDNet
86
+ The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
87
+
88
+ Terms of the Nvidia License:
89
+ ---------------------------------------------
90
+
91
+ 1. Definitions
92
+
93
+ "Licensor" means any person or entity that distributes its Work.
94
+
95
+ "Software" means the original work of authorship made available under
96
+ this License.
97
+
98
+ "Work" means the Software and any additions to or derivative works of
99
+ the Software that are made available under this License.
100
+
101
+ "Nvidia Processors" means any central processing unit (CPU), graphics
102
+ processing unit (GPU), field-programmable gate array (FPGA),
103
+ application-specific integrated circuit (ASIC) or any combination
104
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
105
+
106
+ The terms "reproduce," "reproduction," "derivative works," and
107
+ "distribution" have the meaning as provided under U.S. copyright law;
108
+ provided, however, that for the purposes of this License, derivative
109
+ works shall not include works that remain separable from, or merely
110
+ link (or bind by name) to the interfaces of, the Work.
111
+
112
+ Works, including the Software, are "made available" under this License
113
+ by including in or with the Work either (a) a copyright notice
114
+ referencing the applicability of this License to the Work, or (b) a
115
+ copy of this License.
116
+
117
+ 2. License Grants
118
+
119
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
120
+ License, each Licensor grants to you a perpetual, worldwide,
121
+ non-exclusive, royalty-free, copyright license to reproduce,
122
+ prepare derivative works of, publicly display, publicly perform,
123
+ sublicense and distribute its Work and any resulting derivative
124
+ works in any form.
125
+
126
+ 3. Limitations
127
+
128
+ 3.1 Redistribution. You may reproduce or distribute the Work only
129
+ if (a) you do so under this License, (b) you include a complete
130
+ copy of this License with your distribution, and (c) you retain
131
+ without modification any copyright, patent, trademark, or
132
+ attribution notices that are present in the Work.
133
+
134
+ 3.2 Derivative Works. You may specify that additional or different
135
+ terms apply to the use, reproduction, and distribution of your
136
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
137
+ provide that the use limitation in Section 3.3 applies to your
138
+ derivative works, and (b) you identify the specific derivative
139
+ works that are subject to Your Terms. Notwithstanding Your Terms,
140
+ this License (including the redistribution requirements in Section
141
+ 3.1) will continue to apply to the Work itself.
142
+
143
+ 3.3 Use Limitation. The Work and any derivative works thereof only
144
+ may be used or intended for use non-commercially. The Work or
145
+ derivative works thereof may be used or intended for use by Nvidia
146
+ or its affiliates commercially or non-commercially. As used herein,
147
+ "non-commercially" means for research or evaluation purposes only.
148
+
149
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
150
+ against any Licensor (including any claim, cross-claim or
151
+ counterclaim in a lawsuit) to enforce any patents that you allege
152
+ are infringed by any Work, then your rights under this License from
153
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
154
+ terminate immediately.
155
+
156
+ 3.5 Trademarks. This License does not grant any rights to use any
157
+ Licensor's or its affiliates' names, logos, or trademarks, except
158
+ as necessary to reproduce the notices described in this License.
159
+
160
+ 3.6 Termination. If you violate any term of this License, then your
161
+ rights under this License (including the grants in Sections 2.1 and
162
+ 2.2) will terminate immediately.
163
+
164
+ 4. Disclaimer of Warranty.
165
+
166
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
167
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
168
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
169
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
170
+ THIS LICENSE.
171
+
172
+ 5. Limitation of Liability.
173
+
174
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
175
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
176
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
177
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
178
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
179
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
180
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
181
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
182
+ THE POSSIBILITY OF SUCH DAMAGES.
183
+
184
+ MIT License
185
+
186
+ Copyright (c) 2019 Kim Seonghyeon
187
+
188
+ Permission is hereby granted, free of charge, to any person obtaining a copy
189
+ of this software and associated documentation files (the "Software"), to deal
190
+ in the Software without restriction, including without limitation the rights
191
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
192
+ copies of the Software, and to permit persons to whom the Software is
193
+ furnished to do so, subject to the following conditions:
194
+
195
+ The above copyright notice and this permission notice shall be included in all
196
+ copies or substantial portions of the Software.
197
+
198
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
199
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
200
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
201
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
202
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
203
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
204
+ SOFTWARE.
205
+
206
+
207
+
208
+ Open Source Software licensed under the BSD 3-Clause license:
209
+ ---------------------------------------------
210
+ 1. torchvision
211
+ Copyright (c) Soumith Chintala 2016,
212
+ All rights reserved.
213
+
214
+ 2. torch
215
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
216
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
217
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
218
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
219
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
221
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
222
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
223
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
224
+
225
+
226
+ Terms of the BSD 3-Clause License:
227
+ ---------------------------------------------
228
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
229
+
230
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
231
+
232
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
233
+
234
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
235
+
236
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+
239
+
240
+ Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
241
+ ---------------------------------------------
242
+ 1. numpy
243
+ Copyright (c) 2005-2020, NumPy Developers.
244
+ All rights reserved.
245
+
246
+ A copy of BSD 3-Clause License is included in this file.
247
+
248
+ The NumPy repository and source distributions bundle several libraries that are
249
+ compatibly licensed. We list these here.
250
+
251
+ Name: Numpydoc
252
+ Files: doc/sphinxext/numpydoc/*
253
+ License: BSD-2-Clause
254
+ For details, see doc/sphinxext/LICENSE.txt
255
+
256
+ Name: scipy-sphinx-theme
257
+ Files: doc/scipy-sphinx-theme/*
258
+ License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
259
+ For details, see doc/scipy-sphinx-theme/LICENSE.txt
260
+
261
+ Name: lapack-lite
262
+ Files: numpy/linalg/lapack_lite/*
263
+ License: BSD-3-Clause
264
+ For details, see numpy/linalg/lapack_lite/LICENSE.txt
265
+
266
+ Name: tempita
267
+ Files: tools/npy_tempita/*
268
+ License: MIT
269
+ For details, see tools/npy_tempita/license.txt
270
+
271
+ Name: dragon4
272
+ Files: numpy/core/src/multiarray/dragon4.c
273
+ License: MIT
274
+ For license text, see numpy/core/src/multiarray/dragon4.c
275
+
276
+
277
+
278
+ Open Source Software licensed under the MIT license:
279
+ ---------------------------------------------
280
+ 1. facexlib
281
+ Copyright (c) 2020 Xintao Wang
282
+
283
+ 2. opencv-python
284
+ Copyright (c) Olli-Pekka Heinisuo
285
+ Please note that only files in cv2 package are used.
286
+
287
+
288
+ Terms of the MIT License:
289
+ ---------------------------------------------
290
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
291
+
292
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
293
+
294
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
295
+
296
+
297
+
298
+ Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
299
+ ---------------------------------------------
300
+ 1. tqdm
301
+ Copyright (c) 2013 noamraph
302
+
303
+ `tqdm` is a product of collaborative work.
304
+ Unless otherwise stated, all authors (see commit logs) retain copyright
305
+ for their respective work, and release the work under the MIT licence
306
+ (text below).
307
+
308
+ Exceptions or notable authors are listed below
309
+ in reverse chronological order:
310
+
311
+ * files: *
312
+ MPLv2.0 2015-2020 (c) Casper da Costa-Luis
313
+ [casperdcl](https://github.com/casperdcl).
314
+ * files: tqdm/_tqdm.py
315
+ MIT 2016 (c) [PR #96] on behalf of Google Inc.
316
+ * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
317
+ MIT 2013 (c) Noam Yorav-Raphael, original author.
318
+
319
+ [PR #96]: https://github.com/tqdm/tqdm/pull/96
320
+
321
+
322
+ Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
323
+ -----------------------------------------------
324
+
325
+ This Source Code Form is subject to the terms of the
326
+ Mozilla Public License, v. 2.0.
327
+ If a copy of the MPL was not distributed with this file,
328
+ You can obtain one at https://mozilla.org/MPL/2.0/.
329
+
330
+
331
+ MIT License (MIT)
332
+ -----------------
333
+
334
+ Copyright (c) 2013 noamraph
335
+
336
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
337
+ this software and associated documentation files (the "Software"), to deal in
338
+ the Software without restriction, including without limitation the rights to
339
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
340
+ the Software, and to permit persons to whom the Software is furnished to do so,
341
+ subject to the following conditions:
342
+
343
+ The above copyright notice and this permission notice shall be included in all
344
+ copies or substantial portions of the Software.
345
+
346
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
347
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
348
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
349
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
350
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
351
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ include assets/*
2
+ include inputs/*
3
+ include scripts/*.py
4
+ include inference_gfpgan.py
5
+ include VERSION
6
+ include LICENSE
7
+ include requirements.txt
8
+ include gfpgan/weights/README.md
PaperModel.md ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation
2
+
3
+ We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. See [here](README.md#installation) for this easier installation.<br>
4
+ If you want want to use the original model in our paper, please follow the instructions below.
5
+
6
+ 1. Clone repo
7
+
8
+ ```bash
9
+ git clone https://github.com/xinntao/GFPGAN.git
10
+ cd GFPGAN
11
+ ```
12
+
13
+ 1. Install dependent packages
14
+
15
+ As StyleGAN2 uses customized PyTorch C++ extensions, you need to **compile them during installation** or **load them just-in-time(JIT)**.
16
+ You can refer to [BasicSR-INSTALL.md](https://github.com/xinntao/BasicSR/blob/master/INSTALL.md) for more details.
17
+
18
+ **Option 1: Load extensions just-in-time(JIT)** (For those just want to do simple inferences, may have less issues)
19
+
20
+ ```bash
21
+ # Install basicsr - https://github.com/xinntao/BasicSR
22
+ # We use BasicSR for both training and inference
23
+ pip install basicsr
24
+
25
+ # Install facexlib - https://github.com/xinntao/facexlib
26
+ # We use face detection and face restoration helper in the facexlib package
27
+ pip install facexlib
28
+
29
+ pip install -r requirements.txt
30
+ python setup.py develop
31
+
32
+ # remember to set BASICSR_JIT=True before your running commands
33
+ ```
34
+
35
+ **Option 2: Compile extensions during installation** (For those need to train/inference for many times)
36
+
37
+ ```bash
38
+ # Install basicsr - https://github.com/xinntao/BasicSR
39
+ # We use BasicSR for both training and inference
40
+ # Set BASICSR_EXT=True to compile the cuda extensions in the BasicSR - It may take several minutes to compile, please be patient
41
+ # Add -vvv for detailed log prints
42
+ BASICSR_EXT=True pip install basicsr -vvv
43
+
44
+ # Install facexlib - https://github.com/xinntao/facexlib
45
+ # We use face detection and face restoration helper in the facexlib package
46
+ pip install facexlib
47
+
48
+ pip install -r requirements.txt
49
+ python setup.py develop
50
+ ```
51
+
52
+ ## :zap: Quick Inference
53
+
54
+ Download pre-trained models: [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth)
55
+
56
+ ```bash
57
+ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth -P experiments/pretrained_models
58
+ ```
59
+
60
+ - Option 1: Load extensions just-in-time(JIT)
61
+
62
+ ```bash
63
+ BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
64
+
65
+ # for aligned images
66
+ BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
67
+ ```
68
+
69
+ - Option 2: Have successfully compiled extensions during installation
70
+
71
+ ```bash
72
+ python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
73
+
74
+ # for aligned images
75
+ python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
76
+ ```
README_CN.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="assets/gfpgan_logo.png" height=130>
3
+ </p>
4
+
5
+ ## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
6
+
7
+ 还未完工,欢迎贡献!
VERSION ADDED
@@ -0,0 +1 @@
 
 
1
+ 1.3.4
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ from gfpgan import GFPGANer
6
+
7
+ # installing version 1 of GFPGAN
8
+ os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth')
9
+ # installing version 1.2 of GFPGAN
10
+ os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth')
11
+ # installing version 1.3 of GFPGAN (latest)
12
+ os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth')
13
+
14
+
15
+ def interface(image: Image, model: str = "GFPGANv1.3.pth"):
16
+ if model == "":
17
+ model = "GFPGANv1.3.pth"
18
+ if model != "GFPGANv1.pth" and model != "GFPGANCleanv1-NoCE-C2.pth" and model != "GFPGANv1.3.pth":
19
+ model = "GFPGANv1.3.pth"
20
+ restorer = GFPGANer(
21
+ model_path=model,
22
+ arch="original" if model == "GFPGANv1.pth" else "clean",
23
+ bg_upsampler=None,
24
+ channel_multiplier=1 if model == "GFPGANv1.pth" else 2,
25
+ upscale=2)
26
+ img = np.array(image)[:, :, ::-1].copy()
27
+ cropped_faces, restored_faces, restored_img = restorer.enhance(
28
+ img,
29
+ align=False,
30
+ only_center_face=False,
31
+ )
32
+ return restored_img
33
+
34
+
35
+ gr.Interface(
36
+ interface,
37
+ [
38
+ gr.components.Image(
39
+ type="pil",
40
+ label="Image",
41
+ ),
42
+ gr.components.Radio([
43
+ "GFPGANv1.pth",
44
+ "GFPGANCleanv1-NoCE-C2.pth",
45
+ "GFPGANv1.3.pth",
46
+ ],
47
+ label="model",
48
+ default="GFPGANv1.3.pth",
49
+ type="value")
50
+ ],
51
+ [gr.components.Image(label="Enhanced Image")],
52
+ ).launch()
gfpgan/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
6
+
7
+ # from .version import *
gfpgan/archs/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import arch modules for registry
6
+ # scan all the files that end with '_arch.py' under the archs folder
7
+ arch_folder = osp.dirname(osp.abspath(__file__))
8
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9
+ # import all the arch modules
10
+ _arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
gfpgan/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
+ def conv3x3(inplanes, outplanes, stride=1):
6
+ """A simple wrapper for 3x3 convolution with padding.
7
+
8
+ Args:
9
+ inplanes (int): Channel number of inputs.
10
+ outplanes (int): Channel number of outputs.
11
+ stride (int): Stride in convolution. Default: 1.
12
+ """
13
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
14
+
15
+
16
+ class BasicBlock(nn.Module):
17
+ """Basic residual block used in the ResNetArcFace architecture.
18
+
19
+ Args:
20
+ inplanes (int): Channel number of inputs.
21
+ planes (int): Channel number of outputs.
22
+ stride (int): Stride in convolution. Default: 1.
23
+ downsample (nn.Module): The downsample module. Default: None.
24
+ """
25
+ expansion = 1 # output channel expansion ratio
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = nn.BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = nn.BatchNorm2d(planes)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class IRBlock(nn.Module):
57
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
58
+
59
+ Args:
60
+ inplanes (int): Channel number of inputs.
61
+ planes (int): Channel number of outputs.
62
+ stride (int): Stride in convolution. Default: 1.
63
+ downsample (nn.Module): The downsample module. Default: None.
64
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
65
+ """
66
+ expansion = 1 # output channel expansion ratio
67
+
68
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
69
+ super(IRBlock, self).__init__()
70
+ self.bn0 = nn.BatchNorm2d(inplanes)
71
+ self.conv1 = conv3x3(inplanes, inplanes)
72
+ self.bn1 = nn.BatchNorm2d(inplanes)
73
+ self.prelu = nn.PReLU()
74
+ self.conv2 = conv3x3(inplanes, planes, stride)
75
+ self.bn2 = nn.BatchNorm2d(planes)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+ self.use_se = use_se
79
+ if self.use_se:
80
+ self.se = SEBlock(planes)
81
+
82
+ def forward(self, x):
83
+ residual = x
84
+ out = self.bn0(x)
85
+ out = self.conv1(out)
86
+ out = self.bn1(out)
87
+ out = self.prelu(out)
88
+
89
+ out = self.conv2(out)
90
+ out = self.bn2(out)
91
+ if self.use_se:
92
+ out = self.se(out)
93
+
94
+ if self.downsample is not None:
95
+ residual = self.downsample(x)
96
+
97
+ out += residual
98
+ out = self.prelu(out)
99
+
100
+ return out
101
+
102
+
103
+ class Bottleneck(nn.Module):
104
+ """Bottleneck block used in the ResNetArcFace architecture.
105
+
106
+ Args:
107
+ inplanes (int): Channel number of inputs.
108
+ planes (int): Channel number of outputs.
109
+ stride (int): Stride in convolution. Default: 1.
110
+ downsample (nn.Module): The downsample module. Default: None.
111
+ """
112
+ expansion = 4 # output channel expansion ratio
113
+
114
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
115
+ super(Bottleneck, self).__init__()
116
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
117
+ self.bn1 = nn.BatchNorm2d(planes)
118
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
119
+ self.bn2 = nn.BatchNorm2d(planes)
120
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
121
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
122
+ self.relu = nn.ReLU(inplace=True)
123
+ self.downsample = downsample
124
+ self.stride = stride
125
+
126
+ def forward(self, x):
127
+ residual = x
128
+
129
+ out = self.conv1(x)
130
+ out = self.bn1(out)
131
+ out = self.relu(out)
132
+
133
+ out = self.conv2(out)
134
+ out = self.bn2(out)
135
+ out = self.relu(out)
136
+
137
+ out = self.conv3(out)
138
+ out = self.bn3(out)
139
+
140
+ if self.downsample is not None:
141
+ residual = self.downsample(x)
142
+
143
+ out += residual
144
+ out = self.relu(out)
145
+
146
+ return out
147
+
148
+
149
+ class SEBlock(nn.Module):
150
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
151
+
152
+ Args:
153
+ channel (int): Channel number of inputs.
154
+ reduction (int): Channel reduction ration. Default: 16.
155
+ """
156
+
157
+ def __init__(self, channel, reduction=16):
158
+ super(SEBlock, self).__init__()
159
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
160
+ self.fc = nn.Sequential(
161
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
162
+ nn.Sigmoid())
163
+
164
+ def forward(self, x):
165
+ b, c, _, _ = x.size()
166
+ y = self.avg_pool(x).view(b, c)
167
+ y = self.fc(y).view(b, c, 1, 1)
168
+ return x * y
169
+
170
+
171
+ @ARCH_REGISTRY.register()
172
+ class ResNetArcFace(nn.Module):
173
+ """ArcFace with ResNet architectures.
174
+
175
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
176
+
177
+ Args:
178
+ block (str): Block used in the ArcFace architecture.
179
+ layers (tuple(int)): Block numbers in each layer.
180
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
181
+ """
182
+
183
+ def __init__(self, block, layers, use_se=True):
184
+ if block == 'IRBlock':
185
+ block = IRBlock
186
+ self.inplanes = 64
187
+ self.use_se = use_se
188
+ super(ResNetArcFace, self).__init__()
189
+
190
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
191
+ self.bn1 = nn.BatchNorm2d(64)
192
+ self.prelu = nn.PReLU()
193
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
194
+ self.layer1 = self._make_layer(block, 64, layers[0])
195
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
196
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
197
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
198
+ self.bn4 = nn.BatchNorm2d(512)
199
+ self.dropout = nn.Dropout()
200
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
201
+ self.bn5 = nn.BatchNorm1d(512)
202
+
203
+ # initialization
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.xavier_normal_(m.weight)
207
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
208
+ nn.init.constant_(m.weight, 1)
209
+ nn.init.constant_(m.bias, 0)
210
+ elif isinstance(m, nn.Linear):
211
+ nn.init.xavier_normal_(m.weight)
212
+ nn.init.constant_(m.bias, 0)
213
+
214
+ def _make_layer(self, block, planes, num_blocks, stride=1):
215
+ downsample = None
216
+ if stride != 1 or self.inplanes != planes * block.expansion:
217
+ downsample = nn.Sequential(
218
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
219
+ nn.BatchNorm2d(planes * block.expansion),
220
+ )
221
+ layers = []
222
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
223
+ self.inplanes = planes
224
+ for _ in range(1, num_blocks):
225
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
226
+
227
+ return nn.Sequential(*layers)
228
+
229
+ def forward(self, x):
230
+ x = self.conv1(x)
231
+ x = self.bn1(x)
232
+ x = self.prelu(x)
233
+ x = self.maxpool(x)
234
+
235
+ x = self.layer1(x)
236
+ x = self.layer2(x)
237
+ x = self.layer3(x)
238
+ x = self.layer4(x)
239
+ x = self.bn4(x)
240
+ x = self.dropout(x)
241
+ x = x.view(x.size(0), -1)
242
+ x = self.fc5(x)
243
+ x = self.bn5(x)
244
+
245
+ return x
gfpgan/archs/gfpgan_bilinear_arch.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from torch import nn
6
+
7
+ from .gfpganv1_arch import ResUpBlock
8
+ from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
9
+ StyleGAN2GeneratorBilinear)
10
+
11
+
12
+ class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
13
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
14
+
15
+ It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
16
+ deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
17
+
18
+ Args:
19
+ out_size (int): The spatial size of outputs.
20
+ num_style_feat (int): Channel number of style features. Default: 512.
21
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
22
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
23
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
24
+ narrow (float): The narrow ratio for channels. Default: 1.
25
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
26
+ """
27
+
28
+ def __init__(self,
29
+ out_size,
30
+ num_style_feat=512,
31
+ num_mlp=8,
32
+ channel_multiplier=2,
33
+ lr_mlp=0.01,
34
+ narrow=1,
35
+ sft_half=False):
36
+ super(StyleGAN2GeneratorBilinearSFT, self).__init__(
37
+ out_size,
38
+ num_style_feat=num_style_feat,
39
+ num_mlp=num_mlp,
40
+ channel_multiplier=channel_multiplier,
41
+ lr_mlp=lr_mlp,
42
+ narrow=narrow)
43
+ self.sft_half = sft_half
44
+
45
+ def forward(self,
46
+ styles,
47
+ conditions,
48
+ input_is_latent=False,
49
+ noise=None,
50
+ randomize_noise=True,
51
+ truncation=1,
52
+ truncation_latent=None,
53
+ inject_index=None,
54
+ return_latents=False):
55
+ """Forward function for StyleGAN2GeneratorBilinearSFT.
56
+
57
+ Args:
58
+ styles (list[Tensor]): Sample codes of styles.
59
+ conditions (list[Tensor]): SFT conditions to generators.
60
+ input_is_latent (bool): Whether input is latent style. Default: False.
61
+ noise (Tensor | None): Input noise or None. Default: None.
62
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
63
+ truncation (float): The truncation ratio. Default: 1.
64
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
65
+ inject_index (int | None): The injection index for mixing noise. Default: None.
66
+ return_latents (bool): Whether to return style latents. Default: False.
67
+ """
68
+ # style codes -> latents with Style MLP layer
69
+ if not input_is_latent:
70
+ styles = [self.style_mlp(s) for s in styles]
71
+ # noises
72
+ if noise is None:
73
+ if randomize_noise:
74
+ noise = [None] * self.num_layers # for each style conv layer
75
+ else: # use the stored noise
76
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
77
+ # style truncation
78
+ if truncation < 1:
79
+ style_truncation = []
80
+ for style in styles:
81
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
82
+ styles = style_truncation
83
+ # get style latents with injection
84
+ if len(styles) == 1:
85
+ inject_index = self.num_latent
86
+
87
+ if styles[0].ndim < 3:
88
+ # repeat latent code for all the layers
89
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
90
+ else: # used for encoder with different latent code for each layer
91
+ latent = styles[0]
92
+ elif len(styles) == 2: # mixing noises
93
+ if inject_index is None:
94
+ inject_index = random.randint(1, self.num_latent - 1)
95
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
96
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
97
+ latent = torch.cat([latent1, latent2], 1)
98
+
99
+ # main generation
100
+ out = self.constant_input(latent.shape[0])
101
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
102
+ skip = self.to_rgb1(out, latent[:, 1])
103
+
104
+ i = 1
105
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
106
+ noise[2::2], self.to_rgbs):
107
+ out = conv1(out, latent[:, i], noise=noise1)
108
+
109
+ # the conditions may have fewer levels
110
+ if i < len(conditions):
111
+ # SFT part to combine the conditions
112
+ if self.sft_half: # only apply SFT to half of the channels
113
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
114
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
115
+ out = torch.cat([out_same, out_sft], dim=1)
116
+ else: # apply SFT to all the channels
117
+ out = out * conditions[i - 1] + conditions[i]
118
+
119
+ out = conv2(out, latent[:, i + 1], noise=noise2)
120
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
121
+ i += 2
122
+
123
+ image = skip
124
+
125
+ if return_latents:
126
+ return image, latent
127
+ else:
128
+ return image, None
129
+
130
+
131
+ @ARCH_REGISTRY.register()
132
+ class GFPGANBilinear(nn.Module):
133
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
134
+
135
+ It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
136
+ deployment. It can be easily converted to the clean version: GFPGANv1Clean.
137
+
138
+
139
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
140
+
141
+ Args:
142
+ out_size (int): The spatial size of outputs.
143
+ num_style_feat (int): Channel number of style features. Default: 512.
144
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
145
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
146
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
147
+
148
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
149
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
150
+ input_is_latent (bool): Whether input is latent style. Default: False.
151
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
152
+ narrow (float): The narrow ratio for channels. Default: 1.
153
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ out_size,
159
+ num_style_feat=512,
160
+ channel_multiplier=1,
161
+ decoder_load_path=None,
162
+ fix_decoder=True,
163
+ # for stylegan decoder
164
+ num_mlp=8,
165
+ lr_mlp=0.01,
166
+ input_is_latent=False,
167
+ different_w=False,
168
+ narrow=1,
169
+ sft_half=False):
170
+
171
+ super(GFPGANBilinear, self).__init__()
172
+ self.input_is_latent = input_is_latent
173
+ self.different_w = different_w
174
+ self.num_style_feat = num_style_feat
175
+
176
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
177
+ channels = {
178
+ '4': int(512 * unet_narrow),
179
+ '8': int(512 * unet_narrow),
180
+ '16': int(512 * unet_narrow),
181
+ '32': int(512 * unet_narrow),
182
+ '64': int(256 * channel_multiplier * unet_narrow),
183
+ '128': int(128 * channel_multiplier * unet_narrow),
184
+ '256': int(64 * channel_multiplier * unet_narrow),
185
+ '512': int(32 * channel_multiplier * unet_narrow),
186
+ '1024': int(16 * channel_multiplier * unet_narrow)
187
+ }
188
+
189
+ self.log_size = int(math.log(out_size, 2))
190
+ first_out_size = 2**(int(math.log(out_size, 2)))
191
+
192
+ self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
193
+
194
+ # downsample
195
+ in_channels = channels[f'{first_out_size}']
196
+ self.conv_body_down = nn.ModuleList()
197
+ for i in range(self.log_size, 2, -1):
198
+ out_channels = channels[f'{2**(i - 1)}']
199
+ self.conv_body_down.append(ResBlock(in_channels, out_channels))
200
+ in_channels = out_channels
201
+
202
+ self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
203
+
204
+ # upsample
205
+ in_channels = channels['4']
206
+ self.conv_body_up = nn.ModuleList()
207
+ for i in range(3, self.log_size + 1):
208
+ out_channels = channels[f'{2**i}']
209
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
210
+ in_channels = out_channels
211
+
212
+ # to RGB
213
+ self.toRGB = nn.ModuleList()
214
+ for i in range(3, self.log_size + 1):
215
+ self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
216
+
217
+ if different_w:
218
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
219
+ else:
220
+ linear_out_channel = num_style_feat
221
+
222
+ self.final_linear = EqualLinear(
223
+ channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
224
+
225
+ # the decoder: stylegan2 generator with SFT modulations
226
+ self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
227
+ out_size=out_size,
228
+ num_style_feat=num_style_feat,
229
+ num_mlp=num_mlp,
230
+ channel_multiplier=channel_multiplier,
231
+ lr_mlp=lr_mlp,
232
+ narrow=narrow,
233
+ sft_half=sft_half)
234
+
235
+ # load pre-trained stylegan2 model if necessary
236
+ if decoder_load_path:
237
+ self.stylegan_decoder.load_state_dict(
238
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
239
+ # fix decoder without updating params
240
+ if fix_decoder:
241
+ for _, param in self.stylegan_decoder.named_parameters():
242
+ param.requires_grad = False
243
+
244
+ # for SFT modulations (scale and shift)
245
+ self.condition_scale = nn.ModuleList()
246
+ self.condition_shift = nn.ModuleList()
247
+ for i in range(3, self.log_size + 1):
248
+ out_channels = channels[f'{2**i}']
249
+ if sft_half:
250
+ sft_out_channels = out_channels
251
+ else:
252
+ sft_out_channels = out_channels * 2
253
+ self.condition_scale.append(
254
+ nn.Sequential(
255
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
256
+ ScaledLeakyReLU(0.2),
257
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
258
+ self.condition_shift.append(
259
+ nn.Sequential(
260
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
261
+ ScaledLeakyReLU(0.2),
262
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
263
+
264
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
265
+ """Forward function for GFPGANBilinear.
266
+
267
+ Args:
268
+ x (Tensor): Input images.
269
+ return_latents (bool): Whether to return style latents. Default: False.
270
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
271
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
272
+ """
273
+ conditions = []
274
+ unet_skips = []
275
+ out_rgbs = []
276
+
277
+ # encoder
278
+ feat = self.conv_body_first(x)
279
+ for i in range(self.log_size - 2):
280
+ feat = self.conv_body_down[i](feat)
281
+ unet_skips.insert(0, feat)
282
+
283
+ feat = self.final_conv(feat)
284
+
285
+ # style code
286
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
287
+ if self.different_w:
288
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
289
+
290
+ # decode
291
+ for i in range(self.log_size - 2):
292
+ # add unet skip
293
+ feat = feat + unet_skips[i]
294
+ # ResUpLayer
295
+ feat = self.conv_body_up[i](feat)
296
+ # generate scale and shift for SFT layers
297
+ scale = self.condition_scale[i](feat)
298
+ conditions.append(scale.clone())
299
+ shift = self.condition_shift[i](feat)
300
+ conditions.append(shift.clone())
301
+ # generate rgb images
302
+ if return_rgb:
303
+ out_rgbs.append(self.toRGB[i](feat))
304
+
305
+ # decoder
306
+ image, _ = self.stylegan_decoder([style_code],
307
+ conditions,
308
+ return_latents=return_latents,
309
+ input_is_latent=self.input_is_latent,
310
+ randomize_noise=randomize_noise)
311
+
312
+ return image, out_rgbs
gfpgan/archs/gfpganv1_arch.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
5
+ StyleGAN2Generator)
6
+ from basicsr.ops.fused_act import FusedLeakyReLU
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
+ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
13
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
14
+
15
+ Args:
16
+ out_size (int): The spatial size of outputs.
17
+ num_style_feat (int): Channel number of style features. Default: 512.
18
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
19
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
20
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
21
+ applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
22
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
23
+ narrow (float): The narrow ratio for channels. Default: 1.
24
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
25
+ """
26
+
27
+ def __init__(self,
28
+ out_size,
29
+ num_style_feat=512,
30
+ num_mlp=8,
31
+ channel_multiplier=2,
32
+ resample_kernel=(1, 3, 3, 1),
33
+ lr_mlp=0.01,
34
+ narrow=1,
35
+ sft_half=False):
36
+ super(StyleGAN2GeneratorSFT, self).__init__(
37
+ out_size,
38
+ num_style_feat=num_style_feat,
39
+ num_mlp=num_mlp,
40
+ channel_multiplier=channel_multiplier,
41
+ resample_kernel=resample_kernel,
42
+ lr_mlp=lr_mlp,
43
+ narrow=narrow)
44
+ self.sft_half = sft_half
45
+
46
+ def forward(self,
47
+ styles,
48
+ conditions,
49
+ input_is_latent=False,
50
+ noise=None,
51
+ randomize_noise=True,
52
+ truncation=1,
53
+ truncation_latent=None,
54
+ inject_index=None,
55
+ return_latents=False):
56
+ """Forward function for StyleGAN2GeneratorSFT.
57
+
58
+ Args:
59
+ styles (list[Tensor]): Sample codes of styles.
60
+ conditions (list[Tensor]): SFT conditions to generators.
61
+ input_is_latent (bool): Whether input is latent style. Default: False.
62
+ noise (Tensor | None): Input noise or None. Default: None.
63
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
64
+ truncation (float): The truncation ratio. Default: 1.
65
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
66
+ inject_index (int | None): The injection index for mixing noise. Default: None.
67
+ return_latents (bool): Whether to return style latents. Default: False.
68
+ """
69
+ # style codes -> latents with Style MLP layer
70
+ if not input_is_latent:
71
+ styles = [self.style_mlp(s) for s in styles]
72
+ # noises
73
+ if noise is None:
74
+ if randomize_noise:
75
+ noise = [None] * self.num_layers # for each style conv layer
76
+ else: # use the stored noise
77
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
78
+ # style truncation
79
+ if truncation < 1:
80
+ style_truncation = []
81
+ for style in styles:
82
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
83
+ styles = style_truncation
84
+ # get style latents with injection
85
+ if len(styles) == 1:
86
+ inject_index = self.num_latent
87
+
88
+ if styles[0].ndim < 3:
89
+ # repeat latent code for all the layers
90
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
91
+ else: # used for encoder with different latent code for each layer
92
+ latent = styles[0]
93
+ elif len(styles) == 2: # mixing noises
94
+ if inject_index is None:
95
+ inject_index = random.randint(1, self.num_latent - 1)
96
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
97
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
98
+ latent = torch.cat([latent1, latent2], 1)
99
+
100
+ # main generation
101
+ out = self.constant_input(latent.shape[0])
102
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
103
+ skip = self.to_rgb1(out, latent[:, 1])
104
+
105
+ i = 1
106
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
107
+ noise[2::2], self.to_rgbs):
108
+ out = conv1(out, latent[:, i], noise=noise1)
109
+
110
+ # the conditions may have fewer levels
111
+ if i < len(conditions):
112
+ # SFT part to combine the conditions
113
+ if self.sft_half: # only apply SFT to half of the channels
114
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
115
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
116
+ out = torch.cat([out_same, out_sft], dim=1)
117
+ else: # apply SFT to all the channels
118
+ out = out * conditions[i - 1] + conditions[i]
119
+
120
+ out = conv2(out, latent[:, i + 1], noise=noise2)
121
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
122
+ i += 2
123
+
124
+ image = skip
125
+
126
+ if return_latents:
127
+ return image, latent
128
+ else:
129
+ return image, None
130
+
131
+
132
+ class ConvUpLayer(nn.Module):
133
+ """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
134
+
135
+ Args:
136
+ in_channels (int): Channel number of the input.
137
+ out_channels (int): Channel number of the output.
138
+ kernel_size (int): Size of the convolving kernel.
139
+ stride (int): Stride of the convolution. Default: 1
140
+ padding (int): Zero-padding added to both sides of the input. Default: 0.
141
+ bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
142
+ bias_init_val (float): Bias initialized value. Default: 0.
143
+ activate (bool): Whether use activateion. Default: True.
144
+ """
145
+
146
+ def __init__(self,
147
+ in_channels,
148
+ out_channels,
149
+ kernel_size,
150
+ stride=1,
151
+ padding=0,
152
+ bias=True,
153
+ bias_init_val=0,
154
+ activate=True):
155
+ super(ConvUpLayer, self).__init__()
156
+ self.in_channels = in_channels
157
+ self.out_channels = out_channels
158
+ self.kernel_size = kernel_size
159
+ self.stride = stride
160
+ self.padding = padding
161
+ # self.scale is used to scale the convolution weights, which is related to the common initializations.
162
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
163
+
164
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
165
+
166
+ if bias and not activate:
167
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
168
+ else:
169
+ self.register_parameter('bias', None)
170
+
171
+ # activation
172
+ if activate:
173
+ if bias:
174
+ self.activation = FusedLeakyReLU(out_channels)
175
+ else:
176
+ self.activation = ScaledLeakyReLU(0.2)
177
+ else:
178
+ self.activation = None
179
+
180
+ def forward(self, x):
181
+ # bilinear upsample
182
+ out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
183
+ # conv
184
+ out = F.conv2d(
185
+ out,
186
+ self.weight * self.scale,
187
+ bias=self.bias,
188
+ stride=self.stride,
189
+ padding=self.padding,
190
+ )
191
+ # activation
192
+ if self.activation is not None:
193
+ out = self.activation(out)
194
+ return out
195
+
196
+
197
+ class ResUpBlock(nn.Module):
198
+ """Residual block with upsampling.
199
+
200
+ Args:
201
+ in_channels (int): Channel number of the input.
202
+ out_channels (int): Channel number of the output.
203
+ """
204
+
205
+ def __init__(self, in_channels, out_channels):
206
+ super(ResUpBlock, self).__init__()
207
+
208
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
209
+ self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
210
+ self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
211
+
212
+ def forward(self, x):
213
+ out = self.conv1(x)
214
+ out = self.conv2(out)
215
+ skip = self.skip(x)
216
+ out = (out + skip) / math.sqrt(2)
217
+ return out
218
+
219
+
220
+ @ARCH_REGISTRY.register()
221
+ class GFPGANv1(nn.Module):
222
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
223
+
224
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
225
+
226
+ Args:
227
+ out_size (int): The spatial size of outputs.
228
+ num_style_feat (int): Channel number of style features. Default: 512.
229
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
230
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
231
+ applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
232
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
233
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
234
+
235
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
236
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
237
+ input_is_latent (bool): Whether input is latent style. Default: False.
238
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
239
+ narrow (float): The narrow ratio for channels. Default: 1.
240
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ out_size,
246
+ num_style_feat=512,
247
+ channel_multiplier=1,
248
+ resample_kernel=(1, 3, 3, 1),
249
+ decoder_load_path=None,
250
+ fix_decoder=True,
251
+ # for stylegan decoder
252
+ num_mlp=8,
253
+ lr_mlp=0.01,
254
+ input_is_latent=False,
255
+ different_w=False,
256
+ narrow=1,
257
+ sft_half=False):
258
+
259
+ super(GFPGANv1, self).__init__()
260
+ self.input_is_latent = input_is_latent
261
+ self.different_w = different_w
262
+ self.num_style_feat = num_style_feat
263
+
264
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
265
+ channels = {
266
+ '4': int(512 * unet_narrow),
267
+ '8': int(512 * unet_narrow),
268
+ '16': int(512 * unet_narrow),
269
+ '32': int(512 * unet_narrow),
270
+ '64': int(256 * channel_multiplier * unet_narrow),
271
+ '128': int(128 * channel_multiplier * unet_narrow),
272
+ '256': int(64 * channel_multiplier * unet_narrow),
273
+ '512': int(32 * channel_multiplier * unet_narrow),
274
+ '1024': int(16 * channel_multiplier * unet_narrow)
275
+ }
276
+
277
+ self.log_size = int(math.log(out_size, 2))
278
+ first_out_size = 2**(int(math.log(out_size, 2)))
279
+
280
+ self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
281
+
282
+ # downsample
283
+ in_channels = channels[f'{first_out_size}']
284
+ self.conv_body_down = nn.ModuleList()
285
+ for i in range(self.log_size, 2, -1):
286
+ out_channels = channels[f'{2**(i - 1)}']
287
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
288
+ in_channels = out_channels
289
+
290
+ self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
291
+
292
+ # upsample
293
+ in_channels = channels['4']
294
+ self.conv_body_up = nn.ModuleList()
295
+ for i in range(3, self.log_size + 1):
296
+ out_channels = channels[f'{2**i}']
297
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
298
+ in_channels = out_channels
299
+
300
+ # to RGB
301
+ self.toRGB = nn.ModuleList()
302
+ for i in range(3, self.log_size + 1):
303
+ self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
304
+
305
+ if different_w:
306
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
307
+ else:
308
+ linear_out_channel = num_style_feat
309
+
310
+ self.final_linear = EqualLinear(
311
+ channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
312
+
313
+ # the decoder: stylegan2 generator with SFT modulations
314
+ self.stylegan_decoder = StyleGAN2GeneratorSFT(
315
+ out_size=out_size,
316
+ num_style_feat=num_style_feat,
317
+ num_mlp=num_mlp,
318
+ channel_multiplier=channel_multiplier,
319
+ resample_kernel=resample_kernel,
320
+ lr_mlp=lr_mlp,
321
+ narrow=narrow,
322
+ sft_half=sft_half)
323
+
324
+ # load pre-trained stylegan2 model if necessary
325
+ if decoder_load_path:
326
+ self.stylegan_decoder.load_state_dict(
327
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
328
+ # fix decoder without updating params
329
+ if fix_decoder:
330
+ for _, param in self.stylegan_decoder.named_parameters():
331
+ param.requires_grad = False
332
+
333
+ # for SFT modulations (scale and shift)
334
+ self.condition_scale = nn.ModuleList()
335
+ self.condition_shift = nn.ModuleList()
336
+ for i in range(3, self.log_size + 1):
337
+ out_channels = channels[f'{2**i}']
338
+ if sft_half:
339
+ sft_out_channels = out_channels
340
+ else:
341
+ sft_out_channels = out_channels * 2
342
+ self.condition_scale.append(
343
+ nn.Sequential(
344
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
345
+ ScaledLeakyReLU(0.2),
346
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
347
+ self.condition_shift.append(
348
+ nn.Sequential(
349
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
350
+ ScaledLeakyReLU(0.2),
351
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
352
+
353
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
354
+ """Forward function for GFPGANv1.
355
+
356
+ Args:
357
+ x (Tensor): Input images.
358
+ return_latents (bool): Whether to return style latents. Default: False.
359
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
360
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
361
+ """
362
+ conditions = []
363
+ unet_skips = []
364
+ out_rgbs = []
365
+
366
+ # encoder
367
+ feat = self.conv_body_first(x)
368
+ for i in range(self.log_size - 2):
369
+ feat = self.conv_body_down[i](feat)
370
+ unet_skips.insert(0, feat)
371
+
372
+ feat = self.final_conv(feat)
373
+
374
+ # style code
375
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
376
+ if self.different_w:
377
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
378
+
379
+ # decode
380
+ for i in range(self.log_size - 2):
381
+ # add unet skip
382
+ feat = feat + unet_skips[i]
383
+ # ResUpLayer
384
+ feat = self.conv_body_up[i](feat)
385
+ # generate scale and shift for SFT layers
386
+ scale = self.condition_scale[i](feat)
387
+ conditions.append(scale.clone())
388
+ shift = self.condition_shift[i](feat)
389
+ conditions.append(shift.clone())
390
+ # generate rgb images
391
+ if return_rgb:
392
+ out_rgbs.append(self.toRGB[i](feat))
393
+
394
+ # decoder
395
+ image, _ = self.stylegan_decoder([style_code],
396
+ conditions,
397
+ return_latents=return_latents,
398
+ input_is_latent=self.input_is_latent,
399
+ randomize_noise=randomize_noise)
400
+
401
+ return image, out_rgbs
402
+
403
+
404
+ @ARCH_REGISTRY.register()
405
+ class FacialComponentDiscriminator(nn.Module):
406
+ """Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
407
+ """
408
+
409
+ def __init__(self):
410
+ super(FacialComponentDiscriminator, self).__init__()
411
+ # It now uses a VGG-style architectrue with fixed model size
412
+ self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
413
+ self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
414
+ self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
415
+ self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
416
+ self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
417
+ self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
418
+
419
+ def forward(self, x, return_feats=False):
420
+ """Forward function for FacialComponentDiscriminator.
421
+
422
+ Args:
423
+ x (Tensor): Input images.
424
+ return_feats (bool): Whether to return intermediate features. Default: False.
425
+ """
426
+ feat = self.conv1(x)
427
+ feat = self.conv3(self.conv2(feat))
428
+ rlt_feats = []
429
+ if return_feats:
430
+ rlt_feats.append(feat.clone())
431
+ feat = self.conv5(self.conv4(feat))
432
+ if return_feats:
433
+ rlt_feats.append(feat.clone())
434
+ out = self.final_conv(feat)
435
+
436
+ if return_feats:
437
+ return out, rlt_feats
438
+ else:
439
+ return out, None
gfpgan/archs/gfpganv1_clean_arch.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.utils.registry import ARCH_REGISTRY
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
9
+
10
+
11
+ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
12
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
13
+
14
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
15
+
16
+ Args:
17
+ out_size (int): The spatial size of outputs.
18
+ num_style_feat (int): Channel number of style features. Default: 512.
19
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
20
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
21
+ narrow (float): The narrow ratio for channels. Default: 1.
22
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
23
+ """
24
+
25
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
26
+ super(StyleGAN2GeneratorCSFT, self).__init__(
27
+ out_size,
28
+ num_style_feat=num_style_feat,
29
+ num_mlp=num_mlp,
30
+ channel_multiplier=channel_multiplier,
31
+ narrow=narrow)
32
+ self.sft_half = sft_half
33
+
34
+ def forward(self,
35
+ styles,
36
+ conditions,
37
+ input_is_latent=False,
38
+ noise=None,
39
+ randomize_noise=True,
40
+ truncation=1,
41
+ truncation_latent=None,
42
+ inject_index=None,
43
+ return_latents=False):
44
+ """Forward function for StyleGAN2GeneratorCSFT.
45
+
46
+ Args:
47
+ styles (list[Tensor]): Sample codes of styles.
48
+ conditions (list[Tensor]): SFT conditions to generators.
49
+ input_is_latent (bool): Whether input is latent style. Default: False.
50
+ noise (Tensor | None): Input noise or None. Default: None.
51
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
52
+ truncation (float): The truncation ratio. Default: 1.
53
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
54
+ inject_index (int | None): The injection index for mixing noise. Default: None.
55
+ return_latents (bool): Whether to return style latents. Default: False.
56
+ """
57
+ # style codes -> latents with Style MLP layer
58
+ if not input_is_latent:
59
+ styles = [self.style_mlp(s) for s in styles]
60
+ # noises
61
+ if noise is None:
62
+ if randomize_noise:
63
+ noise = [None] * self.num_layers # for each style conv layer
64
+ else: # use the stored noise
65
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
66
+ # style truncation
67
+ if truncation < 1:
68
+ style_truncation = []
69
+ for style in styles:
70
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
71
+ styles = style_truncation
72
+ # get style latents with injection
73
+ if len(styles) == 1:
74
+ inject_index = self.num_latent
75
+
76
+ if styles[0].ndim < 3:
77
+ # repeat latent code for all the layers
78
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
79
+ else: # used for encoder with different latent code for each layer
80
+ latent = styles[0]
81
+ elif len(styles) == 2: # mixing noises
82
+ if inject_index is None:
83
+ inject_index = random.randint(1, self.num_latent - 1)
84
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
85
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
86
+ latent = torch.cat([latent1, latent2], 1)
87
+
88
+ # main generation
89
+ out = self.constant_input(latent.shape[0])
90
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
91
+ skip = self.to_rgb1(out, latent[:, 1])
92
+
93
+ i = 1
94
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
95
+ noise[2::2], self.to_rgbs):
96
+ out = conv1(out, latent[:, i], noise=noise1)
97
+
98
+ # the conditions may have fewer levels
99
+ if i < len(conditions):
100
+ # SFT part to combine the conditions
101
+ if self.sft_half: # only apply SFT to half of the channels
102
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
103
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
104
+ out = torch.cat([out_same, out_sft], dim=1)
105
+ else: # apply SFT to all the channels
106
+ out = out * conditions[i - 1] + conditions[i]
107
+
108
+ out = conv2(out, latent[:, i + 1], noise=noise2)
109
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
110
+ i += 2
111
+
112
+ image = skip
113
+
114
+ if return_latents:
115
+ return image, latent
116
+ else:
117
+ return image, None
118
+
119
+
120
+ class ResBlock(nn.Module):
121
+ """Residual block with bilinear upsampling/downsampling.
122
+
123
+ Args:
124
+ in_channels (int): Channel number of the input.
125
+ out_channels (int): Channel number of the output.
126
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
127
+ """
128
+
129
+ def __init__(self, in_channels, out_channels, mode='down'):
130
+ super(ResBlock, self).__init__()
131
+
132
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
133
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
134
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
135
+ if mode == 'down':
136
+ self.scale_factor = 0.5
137
+ elif mode == 'up':
138
+ self.scale_factor = 2
139
+
140
+ def forward(self, x):
141
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
142
+ # upsample/downsample
143
+ out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
144
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
145
+ # skip
146
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
147
+ skip = self.skip(x)
148
+ out = out + skip
149
+ return out
150
+
151
+
152
+ @ARCH_REGISTRY.register()
153
+ class GFPGANv1Clean(nn.Module):
154
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
155
+
156
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
157
+
158
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
159
+
160
+ Args:
161
+ out_size (int): The spatial size of outputs.
162
+ num_style_feat (int): Channel number of style features. Default: 512.
163
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
164
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
165
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
166
+
167
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
168
+ input_is_latent (bool): Whether input is latent style. Default: False.
169
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
170
+ narrow (float): The narrow ratio for channels. Default: 1.
171
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ out_size,
177
+ num_style_feat=512,
178
+ channel_multiplier=1,
179
+ decoder_load_path=None,
180
+ fix_decoder=True,
181
+ # for stylegan decoder
182
+ num_mlp=8,
183
+ input_is_latent=False,
184
+ different_w=False,
185
+ narrow=1,
186
+ sft_half=False):
187
+
188
+ super(GFPGANv1Clean, self).__init__()
189
+ self.input_is_latent = input_is_latent
190
+ self.different_w = different_w
191
+ self.num_style_feat = num_style_feat
192
+
193
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
194
+ channels = {
195
+ '4': int(512 * unet_narrow),
196
+ '8': int(512 * unet_narrow),
197
+ '16': int(512 * unet_narrow),
198
+ '32': int(512 * unet_narrow),
199
+ '64': int(256 * channel_multiplier * unet_narrow),
200
+ '128': int(128 * channel_multiplier * unet_narrow),
201
+ '256': int(64 * channel_multiplier * unet_narrow),
202
+ '512': int(32 * channel_multiplier * unet_narrow),
203
+ '1024': int(16 * channel_multiplier * unet_narrow)
204
+ }
205
+
206
+ self.log_size = int(math.log(out_size, 2))
207
+ first_out_size = 2**(int(math.log(out_size, 2)))
208
+
209
+ self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
210
+
211
+ # downsample
212
+ in_channels = channels[f'{first_out_size}']
213
+ self.conv_body_down = nn.ModuleList()
214
+ for i in range(self.log_size, 2, -1):
215
+ out_channels = channels[f'{2**(i - 1)}']
216
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
217
+ in_channels = out_channels
218
+
219
+ self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
220
+
221
+ # upsample
222
+ in_channels = channels['4']
223
+ self.conv_body_up = nn.ModuleList()
224
+ for i in range(3, self.log_size + 1):
225
+ out_channels = channels[f'{2**i}']
226
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
227
+ in_channels = out_channels
228
+
229
+ # to RGB
230
+ self.toRGB = nn.ModuleList()
231
+ for i in range(3, self.log_size + 1):
232
+ self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
233
+
234
+ if different_w:
235
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
236
+ else:
237
+ linear_out_channel = num_style_feat
238
+
239
+ self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
240
+
241
+ # the decoder: stylegan2 generator with SFT modulations
242
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
243
+ out_size=out_size,
244
+ num_style_feat=num_style_feat,
245
+ num_mlp=num_mlp,
246
+ channel_multiplier=channel_multiplier,
247
+ narrow=narrow,
248
+ sft_half=sft_half)
249
+
250
+ # load pre-trained stylegan2 model if necessary
251
+ if decoder_load_path:
252
+ self.stylegan_decoder.load_state_dict(
253
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
254
+ # fix decoder without updating params
255
+ if fix_decoder:
256
+ for _, param in self.stylegan_decoder.named_parameters():
257
+ param.requires_grad = False
258
+
259
+ # for SFT modulations (scale and shift)
260
+ self.condition_scale = nn.ModuleList()
261
+ self.condition_shift = nn.ModuleList()
262
+ for i in range(3, self.log_size + 1):
263
+ out_channels = channels[f'{2**i}']
264
+ if sft_half:
265
+ sft_out_channels = out_channels
266
+ else:
267
+ sft_out_channels = out_channels * 2
268
+ self.condition_scale.append(
269
+ nn.Sequential(
270
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
271
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
272
+ self.condition_shift.append(
273
+ nn.Sequential(
274
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
275
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
276
+
277
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
278
+ """Forward function for GFPGANv1Clean.
279
+
280
+ Args:
281
+ x (Tensor): Input images.
282
+ return_latents (bool): Whether to return style latents. Default: False.
283
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
284
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
285
+ """
286
+ conditions = []
287
+ unet_skips = []
288
+ out_rgbs = []
289
+
290
+ # encoder
291
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
292
+ for i in range(self.log_size - 2):
293
+ feat = self.conv_body_down[i](feat)
294
+ unet_skips.insert(0, feat)
295
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
296
+
297
+ # style code
298
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
299
+ if self.different_w:
300
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
301
+
302
+ # decode
303
+ for i in range(self.log_size - 2):
304
+ # add unet skip
305
+ feat = feat + unet_skips[i]
306
+ # ResUpLayer
307
+ feat = self.conv_body_up[i](feat)
308
+ # generate scale and shift for SFT layers
309
+ scale = self.condition_scale[i](feat)
310
+ conditions.append(scale.clone())
311
+ shift = self.condition_shift[i](feat)
312
+ conditions.append(shift.clone())
313
+ # generate rgb images
314
+ if return_rgb:
315
+ out_rgbs.append(self.toRGB[i](feat))
316
+
317
+ # decoder
318
+ image, _ = self.stylegan_decoder([style_code],
319
+ conditions,
320
+ return_latents=return_latents,
321
+ input_is_latent=self.input_is_latent,
322
+ randomize_noise=randomize_noise)
323
+
324
+ return image, out_rgbs
gfpgan/archs/stylegan2_bilinear_arch.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class NormStyleCode(nn.Module):
11
+
12
+ def forward(self, x):
13
+ """Normalize the style codes.
14
+
15
+ Args:
16
+ x (Tensor): Style codes with shape (b, c).
17
+
18
+ Returns:
19
+ Tensor: Normalized tensor.
20
+ """
21
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22
+
23
+
24
+ class EqualLinear(nn.Module):
25
+ """Equalized Linear as StyleGAN2.
26
+
27
+ Args:
28
+ in_channels (int): Size of each sample.
29
+ out_channels (int): Size of each output sample.
30
+ bias (bool): If set to ``False``, the layer will not learn an additive
31
+ bias. Default: ``True``.
32
+ bias_init_val (float): Bias initialized value. Default: 0.
33
+ lr_mul (float): Learning rate multiplier. Default: 1.
34
+ activation (None | str): The activation after ``linear`` operation.
35
+ Supported: 'fused_lrelu', None. Default: None.
36
+ """
37
+
38
+ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
39
+ super(EqualLinear, self).__init__()
40
+ self.in_channels = in_channels
41
+ self.out_channels = out_channels
42
+ self.lr_mul = lr_mul
43
+ self.activation = activation
44
+ if self.activation not in ['fused_lrelu', None]:
45
+ raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
46
+ "Supported ones are: ['fused_lrelu', None].")
47
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
48
+
49
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
50
+ if bias:
51
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
52
+ else:
53
+ self.register_parameter('bias', None)
54
+
55
+ def forward(self, x):
56
+ if self.bias is None:
57
+ bias = None
58
+ else:
59
+ bias = self.bias * self.lr_mul
60
+ if self.activation == 'fused_lrelu':
61
+ out = F.linear(x, self.weight * self.scale)
62
+ out = fused_leaky_relu(out, bias)
63
+ else:
64
+ out = F.linear(x, self.weight * self.scale, bias=bias)
65
+ return out
66
+
67
+ def __repr__(self):
68
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
69
+ f'out_channels={self.out_channels}, bias={self.bias is not None})')
70
+
71
+
72
+ class ModulatedConv2d(nn.Module):
73
+ """Modulated Conv2d used in StyleGAN2.
74
+
75
+ There is no bias in ModulatedConv2d.
76
+
77
+ Args:
78
+ in_channels (int): Channel number of the input.
79
+ out_channels (int): Channel number of the output.
80
+ kernel_size (int): Size of the convolving kernel.
81
+ num_style_feat (int): Channel number of style features.
82
+ demodulate (bool): Whether to demodulate in the conv layer.
83
+ Default: True.
84
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
85
+ Default: None.
86
+ eps (float): A value added to the denominator for numerical stability.
87
+ Default: 1e-8.
88
+ """
89
+
90
+ def __init__(self,
91
+ in_channels,
92
+ out_channels,
93
+ kernel_size,
94
+ num_style_feat,
95
+ demodulate=True,
96
+ sample_mode=None,
97
+ eps=1e-8,
98
+ interpolation_mode='bilinear'):
99
+ super(ModulatedConv2d, self).__init__()
100
+ self.in_channels = in_channels
101
+ self.out_channels = out_channels
102
+ self.kernel_size = kernel_size
103
+ self.demodulate = demodulate
104
+ self.sample_mode = sample_mode
105
+ self.eps = eps
106
+ self.interpolation_mode = interpolation_mode
107
+ if self.interpolation_mode == 'nearest':
108
+ self.align_corners = None
109
+ else:
110
+ self.align_corners = False
111
+
112
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
113
+ # modulation inside each modulated conv
114
+ self.modulation = EqualLinear(
115
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
116
+
117
+ self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
118
+ self.padding = kernel_size // 2
119
+
120
+ def forward(self, x, style):
121
+ """Forward function.
122
+
123
+ Args:
124
+ x (Tensor): Tensor with shape (b, c, h, w).
125
+ style (Tensor): Tensor with shape (b, num_style_feat).
126
+
127
+ Returns:
128
+ Tensor: Modulated tensor after convolution.
129
+ """
130
+ b, c, h, w = x.shape # c = c_in
131
+ # weight modulation
132
+ style = self.modulation(style).view(b, 1, c, 1, 1)
133
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
134
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
135
+
136
+ if self.demodulate:
137
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
138
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
139
+
140
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
141
+
142
+ if self.sample_mode == 'upsample':
143
+ x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
144
+ elif self.sample_mode == 'downsample':
145
+ x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
146
+
147
+ b, c, h, w = x.shape
148
+ x = x.view(1, b * c, h, w)
149
+ # weight: (b*c_out, c_in, k, k), groups=b
150
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
151
+ out = out.view(b, self.out_channels, *out.shape[2:4])
152
+
153
+ return out
154
+
155
+ def __repr__(self):
156
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
157
+ f'out_channels={self.out_channels}, '
158
+ f'kernel_size={self.kernel_size}, '
159
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
160
+
161
+
162
+ class StyleConv(nn.Module):
163
+ """Style conv.
164
+
165
+ Args:
166
+ in_channels (int): Channel number of the input.
167
+ out_channels (int): Channel number of the output.
168
+ kernel_size (int): Size of the convolving kernel.
169
+ num_style_feat (int): Channel number of style features.
170
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
171
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
172
+ Default: None.
173
+ """
174
+
175
+ def __init__(self,
176
+ in_channels,
177
+ out_channels,
178
+ kernel_size,
179
+ num_style_feat,
180
+ demodulate=True,
181
+ sample_mode=None,
182
+ interpolation_mode='bilinear'):
183
+ super(StyleConv, self).__init__()
184
+ self.modulated_conv = ModulatedConv2d(
185
+ in_channels,
186
+ out_channels,
187
+ kernel_size,
188
+ num_style_feat,
189
+ demodulate=demodulate,
190
+ sample_mode=sample_mode,
191
+ interpolation_mode=interpolation_mode)
192
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
193
+ self.activate = FusedLeakyReLU(out_channels)
194
+
195
+ def forward(self, x, style, noise=None):
196
+ # modulate
197
+ out = self.modulated_conv(x, style)
198
+ # noise injection
199
+ if noise is None:
200
+ b, _, h, w = out.shape
201
+ noise = out.new_empty(b, 1, h, w).normal_()
202
+ out = out + self.weight * noise
203
+ # activation (with bias)
204
+ out = self.activate(out)
205
+ return out
206
+
207
+
208
+ class ToRGB(nn.Module):
209
+ """To RGB from features.
210
+
211
+ Args:
212
+ in_channels (int): Channel number of input.
213
+ num_style_feat (int): Channel number of style features.
214
+ upsample (bool): Whether to upsample. Default: True.
215
+ """
216
+
217
+ def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
218
+ super(ToRGB, self).__init__()
219
+ self.upsample = upsample
220
+ self.interpolation_mode = interpolation_mode
221
+ if self.interpolation_mode == 'nearest':
222
+ self.align_corners = None
223
+ else:
224
+ self.align_corners = False
225
+ self.modulated_conv = ModulatedConv2d(
226
+ in_channels,
227
+ 3,
228
+ kernel_size=1,
229
+ num_style_feat=num_style_feat,
230
+ demodulate=False,
231
+ sample_mode=None,
232
+ interpolation_mode=interpolation_mode)
233
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
234
+
235
+ def forward(self, x, style, skip=None):
236
+ """Forward function.
237
+
238
+ Args:
239
+ x (Tensor): Feature tensor with shape (b, c, h, w).
240
+ style (Tensor): Tensor with shape (b, num_style_feat).
241
+ skip (Tensor): Base/skip tensor. Default: None.
242
+
243
+ Returns:
244
+ Tensor: RGB images.
245
+ """
246
+ out = self.modulated_conv(x, style)
247
+ out = out + self.bias
248
+ if skip is not None:
249
+ if self.upsample:
250
+ skip = F.interpolate(
251
+ skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
252
+ out = out + skip
253
+ return out
254
+
255
+
256
+ class ConstantInput(nn.Module):
257
+ """Constant input.
258
+
259
+ Args:
260
+ num_channel (int): Channel number of constant input.
261
+ size (int): Spatial size of constant input.
262
+ """
263
+
264
+ def __init__(self, num_channel, size):
265
+ super(ConstantInput, self).__init__()
266
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
267
+
268
+ def forward(self, batch):
269
+ out = self.weight.repeat(batch, 1, 1, 1)
270
+ return out
271
+
272
+
273
+ @ARCH_REGISTRY.register()
274
+ class StyleGAN2GeneratorBilinear(nn.Module):
275
+ """StyleGAN2 Generator.
276
+
277
+ Args:
278
+ out_size (int): The spatial size of outputs.
279
+ num_style_feat (int): Channel number of style features. Default: 512.
280
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
281
+ channel_multiplier (int): Channel multiplier for large networks of
282
+ StyleGAN2. Default: 2.
283
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
284
+ narrow (float): Narrow ratio for channels. Default: 1.0.
285
+ """
286
+
287
+ def __init__(self,
288
+ out_size,
289
+ num_style_feat=512,
290
+ num_mlp=8,
291
+ channel_multiplier=2,
292
+ lr_mlp=0.01,
293
+ narrow=1,
294
+ interpolation_mode='bilinear'):
295
+ super(StyleGAN2GeneratorBilinear, self).__init__()
296
+ # Style MLP layers
297
+ self.num_style_feat = num_style_feat
298
+ style_mlp_layers = [NormStyleCode()]
299
+ for i in range(num_mlp):
300
+ style_mlp_layers.append(
301
+ EqualLinear(
302
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
303
+ activation='fused_lrelu'))
304
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
305
+
306
+ channels = {
307
+ '4': int(512 * narrow),
308
+ '8': int(512 * narrow),
309
+ '16': int(512 * narrow),
310
+ '32': int(512 * narrow),
311
+ '64': int(256 * channel_multiplier * narrow),
312
+ '128': int(128 * channel_multiplier * narrow),
313
+ '256': int(64 * channel_multiplier * narrow),
314
+ '512': int(32 * channel_multiplier * narrow),
315
+ '1024': int(16 * channel_multiplier * narrow)
316
+ }
317
+ self.channels = channels
318
+
319
+ self.constant_input = ConstantInput(channels['4'], size=4)
320
+ self.style_conv1 = StyleConv(
321
+ channels['4'],
322
+ channels['4'],
323
+ kernel_size=3,
324
+ num_style_feat=num_style_feat,
325
+ demodulate=True,
326
+ sample_mode=None,
327
+ interpolation_mode=interpolation_mode)
328
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
329
+
330
+ self.log_size = int(math.log(out_size, 2))
331
+ self.num_layers = (self.log_size - 2) * 2 + 1
332
+ self.num_latent = self.log_size * 2 - 2
333
+
334
+ self.style_convs = nn.ModuleList()
335
+ self.to_rgbs = nn.ModuleList()
336
+ self.noises = nn.Module()
337
+
338
+ in_channels = channels['4']
339
+ # noise
340
+ for layer_idx in range(self.num_layers):
341
+ resolution = 2**((layer_idx + 5) // 2)
342
+ shape = [1, 1, resolution, resolution]
343
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
344
+ # style convs and to_rgbs
345
+ for i in range(3, self.log_size + 1):
346
+ out_channels = channels[f'{2**i}']
347
+ self.style_convs.append(
348
+ StyleConv(
349
+ in_channels,
350
+ out_channels,
351
+ kernel_size=3,
352
+ num_style_feat=num_style_feat,
353
+ demodulate=True,
354
+ sample_mode='upsample',
355
+ interpolation_mode=interpolation_mode))
356
+ self.style_convs.append(
357
+ StyleConv(
358
+ out_channels,
359
+ out_channels,
360
+ kernel_size=3,
361
+ num_style_feat=num_style_feat,
362
+ demodulate=True,
363
+ sample_mode=None,
364
+ interpolation_mode=interpolation_mode))
365
+ self.to_rgbs.append(
366
+ ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
367
+ in_channels = out_channels
368
+
369
+ def make_noise(self):
370
+ """Make noise for noise injection."""
371
+ device = self.constant_input.weight.device
372
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
373
+
374
+ for i in range(3, self.log_size + 1):
375
+ for _ in range(2):
376
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
377
+
378
+ return noises
379
+
380
+ def get_latent(self, x):
381
+ return self.style_mlp(x)
382
+
383
+ def mean_latent(self, num_latent):
384
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
385
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
386
+ return latent
387
+
388
+ def forward(self,
389
+ styles,
390
+ input_is_latent=False,
391
+ noise=None,
392
+ randomize_noise=True,
393
+ truncation=1,
394
+ truncation_latent=None,
395
+ inject_index=None,
396
+ return_latents=False):
397
+ """Forward function for StyleGAN2Generator.
398
+
399
+ Args:
400
+ styles (list[Tensor]): Sample codes of styles.
401
+ input_is_latent (bool): Whether input is latent style.
402
+ Default: False.
403
+ noise (Tensor | None): Input noise or None. Default: None.
404
+ randomize_noise (bool): Randomize noise, used when 'noise' is
405
+ False. Default: True.
406
+ truncation (float): TODO. Default: 1.
407
+ truncation_latent (Tensor | None): TODO. Default: None.
408
+ inject_index (int | None): The injection index for mixing noise.
409
+ Default: None.
410
+ return_latents (bool): Whether to return style latents.
411
+ Default: False.
412
+ """
413
+ # style codes -> latents with Style MLP layer
414
+ if not input_is_latent:
415
+ styles = [self.style_mlp(s) for s in styles]
416
+ # noises
417
+ if noise is None:
418
+ if randomize_noise:
419
+ noise = [None] * self.num_layers # for each style conv layer
420
+ else: # use the stored noise
421
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
422
+ # style truncation
423
+ if truncation < 1:
424
+ style_truncation = []
425
+ for style in styles:
426
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
427
+ styles = style_truncation
428
+ # get style latent with injection
429
+ if len(styles) == 1:
430
+ inject_index = self.num_latent
431
+
432
+ if styles[0].ndim < 3:
433
+ # repeat latent code for all the layers
434
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
435
+ else: # used for encoder with different latent code for each layer
436
+ latent = styles[0]
437
+ elif len(styles) == 2: # mixing noises
438
+ if inject_index is None:
439
+ inject_index = random.randint(1, self.num_latent - 1)
440
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
441
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
442
+ latent = torch.cat([latent1, latent2], 1)
443
+
444
+ # main generation
445
+ out = self.constant_input(latent.shape[0])
446
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
447
+ skip = self.to_rgb1(out, latent[:, 1])
448
+
449
+ i = 1
450
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
451
+ noise[2::2], self.to_rgbs):
452
+ out = conv1(out, latent[:, i], noise=noise1)
453
+ out = conv2(out, latent[:, i + 1], noise=noise2)
454
+ skip = to_rgb(out, latent[:, i + 2], skip)
455
+ i += 2
456
+
457
+ image = skip
458
+
459
+ if return_latents:
460
+ return image, latent
461
+ else:
462
+ return image, None
463
+
464
+
465
+ class ScaledLeakyReLU(nn.Module):
466
+ """Scaled LeakyReLU.
467
+
468
+ Args:
469
+ negative_slope (float): Negative slope. Default: 0.2.
470
+ """
471
+
472
+ def __init__(self, negative_slope=0.2):
473
+ super(ScaledLeakyReLU, self).__init__()
474
+ self.negative_slope = negative_slope
475
+
476
+ def forward(self, x):
477
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
478
+ return out * math.sqrt(2)
479
+
480
+
481
+ class EqualConv2d(nn.Module):
482
+ """Equalized Linear as StyleGAN2.
483
+
484
+ Args:
485
+ in_channels (int): Channel number of the input.
486
+ out_channels (int): Channel number of the output.
487
+ kernel_size (int): Size of the convolving kernel.
488
+ stride (int): Stride of the convolution. Default: 1
489
+ padding (int): Zero-padding added to both sides of the input.
490
+ Default: 0.
491
+ bias (bool): If ``True``, adds a learnable bias to the output.
492
+ Default: ``True``.
493
+ bias_init_val (float): Bias initialized value. Default: 0.
494
+ """
495
+
496
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
497
+ super(EqualConv2d, self).__init__()
498
+ self.in_channels = in_channels
499
+ self.out_channels = out_channels
500
+ self.kernel_size = kernel_size
501
+ self.stride = stride
502
+ self.padding = padding
503
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
504
+
505
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
506
+ if bias:
507
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
508
+ else:
509
+ self.register_parameter('bias', None)
510
+
511
+ def forward(self, x):
512
+ out = F.conv2d(
513
+ x,
514
+ self.weight * self.scale,
515
+ bias=self.bias,
516
+ stride=self.stride,
517
+ padding=self.padding,
518
+ )
519
+
520
+ return out
521
+
522
+ def __repr__(self):
523
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
524
+ f'out_channels={self.out_channels}, '
525
+ f'kernel_size={self.kernel_size},'
526
+ f' stride={self.stride}, padding={self.padding}, '
527
+ f'bias={self.bias is not None})')
528
+
529
+
530
+ class ConvLayer(nn.Sequential):
531
+ """Conv Layer used in StyleGAN2 Discriminator.
532
+
533
+ Args:
534
+ in_channels (int): Channel number of the input.
535
+ out_channels (int): Channel number of the output.
536
+ kernel_size (int): Kernel size.
537
+ downsample (bool): Whether downsample by a factor of 2.
538
+ Default: False.
539
+ bias (bool): Whether with bias. Default: True.
540
+ activate (bool): Whether use activateion. Default: True.
541
+ """
542
+
543
+ def __init__(self,
544
+ in_channels,
545
+ out_channels,
546
+ kernel_size,
547
+ downsample=False,
548
+ bias=True,
549
+ activate=True,
550
+ interpolation_mode='bilinear'):
551
+ layers = []
552
+ self.interpolation_mode = interpolation_mode
553
+ # downsample
554
+ if downsample:
555
+ if self.interpolation_mode == 'nearest':
556
+ self.align_corners = None
557
+ else:
558
+ self.align_corners = False
559
+
560
+ layers.append(
561
+ torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
562
+ stride = 1
563
+ self.padding = kernel_size // 2
564
+ # conv
565
+ layers.append(
566
+ EqualConv2d(
567
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
568
+ and not activate))
569
+ # activation
570
+ if activate:
571
+ if bias:
572
+ layers.append(FusedLeakyReLU(out_channels))
573
+ else:
574
+ layers.append(ScaledLeakyReLU(0.2))
575
+
576
+ super(ConvLayer, self).__init__(*layers)
577
+
578
+
579
+ class ResBlock(nn.Module):
580
+ """Residual block used in StyleGAN2 Discriminator.
581
+
582
+ Args:
583
+ in_channels (int): Channel number of the input.
584
+ out_channels (int): Channel number of the output.
585
+ """
586
+
587
+ def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
588
+ super(ResBlock, self).__init__()
589
+
590
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
591
+ self.conv2 = ConvLayer(
592
+ in_channels,
593
+ out_channels,
594
+ 3,
595
+ downsample=True,
596
+ interpolation_mode=interpolation_mode,
597
+ bias=True,
598
+ activate=True)
599
+ self.skip = ConvLayer(
600
+ in_channels,
601
+ out_channels,
602
+ 1,
603
+ downsample=True,
604
+ interpolation_mode=interpolation_mode,
605
+ bias=False,
606
+ activate=False)
607
+
608
+ def forward(self, x):
609
+ out = self.conv1(x)
610
+ out = self.conv2(out)
611
+ skip = self.skip(x)
612
+ out = (out + skip) / math.sqrt(2)
613
+ return out
gfpgan/archs/stylegan2_clean_arch.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from basicsr.archs.arch_util import default_init_weights
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class NormStyleCode(nn.Module):
11
+
12
+ def forward(self, x):
13
+ """Normalize the style codes.
14
+
15
+ Args:
16
+ x (Tensor): Style codes with shape (b, c).
17
+
18
+ Returns:
19
+ Tensor: Normalized tensor.
20
+ """
21
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22
+
23
+
24
+ class ModulatedConv2d(nn.Module):
25
+ """Modulated Conv2d used in StyleGAN2.
26
+
27
+ There is no bias in ModulatedConv2d.
28
+
29
+ Args:
30
+ in_channels (int): Channel number of the input.
31
+ out_channels (int): Channel number of the output.
32
+ kernel_size (int): Size of the convolving kernel.
33
+ num_style_feat (int): Channel number of style features.
34
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
35
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
36
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
37
+ """
38
+
39
+ def __init__(self,
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size,
43
+ num_style_feat,
44
+ demodulate=True,
45
+ sample_mode=None,
46
+ eps=1e-8):
47
+ super(ModulatedConv2d, self).__init__()
48
+ self.in_channels = in_channels
49
+ self.out_channels = out_channels
50
+ self.kernel_size = kernel_size
51
+ self.demodulate = demodulate
52
+ self.sample_mode = sample_mode
53
+ self.eps = eps
54
+
55
+ # modulation inside each modulated conv
56
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
57
+ # initialization
58
+ default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
59
+
60
+ self.weight = nn.Parameter(
61
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
62
+ math.sqrt(in_channels * kernel_size**2))
63
+ self.padding = kernel_size // 2
64
+
65
+ def forward(self, x, style):
66
+ """Forward function.
67
+
68
+ Args:
69
+ x (Tensor): Tensor with shape (b, c, h, w).
70
+ style (Tensor): Tensor with shape (b, num_style_feat).
71
+
72
+ Returns:
73
+ Tensor: Modulated tensor after convolution.
74
+ """
75
+ b, c, h, w = x.shape # c = c_in
76
+ # weight modulation
77
+ style = self.modulation(style).view(b, 1, c, 1, 1)
78
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
79
+ weight = self.weight * style # (b, c_out, c_in, k, k)
80
+
81
+ if self.demodulate:
82
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
83
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
84
+
85
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
86
+
87
+ # upsample or downsample if necessary
88
+ if self.sample_mode == 'upsample':
89
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
90
+ elif self.sample_mode == 'downsample':
91
+ x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
92
+
93
+ b, c, h, w = x.shape
94
+ x = x.view(1, b * c, h, w)
95
+ # weight: (b*c_out, c_in, k, k), groups=b
96
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
97
+ out = out.view(b, self.out_channels, *out.shape[2:4])
98
+
99
+ return out
100
+
101
+ def __repr__(self):
102
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
103
+ f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
104
+
105
+
106
+ class StyleConv(nn.Module):
107
+ """Style conv used in StyleGAN2.
108
+
109
+ Args:
110
+ in_channels (int): Channel number of the input.
111
+ out_channels (int): Channel number of the output.
112
+ kernel_size (int): Size of the convolving kernel.
113
+ num_style_feat (int): Channel number of style features.
114
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
115
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
116
+ """
117
+
118
+ def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
119
+ super(StyleConv, self).__init__()
120
+ self.modulated_conv = ModulatedConv2d(
121
+ in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
122
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
123
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
124
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
125
+
126
+ def forward(self, x, style, noise=None):
127
+ # modulate
128
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
129
+ # noise injection
130
+ if noise is None:
131
+ b, _, h, w = out.shape
132
+ noise = out.new_empty(b, 1, h, w).normal_()
133
+ out = out + self.weight * noise
134
+ # add bias
135
+ out = out + self.bias
136
+ # activation
137
+ out = self.activate(out)
138
+ return out
139
+
140
+
141
+ class ToRGB(nn.Module):
142
+ """To RGB (image space) from features.
143
+
144
+ Args:
145
+ in_channels (int): Channel number of input.
146
+ num_style_feat (int): Channel number of style features.
147
+ upsample (bool): Whether to upsample. Default: True.
148
+ """
149
+
150
+ def __init__(self, in_channels, num_style_feat, upsample=True):
151
+ super(ToRGB, self).__init__()
152
+ self.upsample = upsample
153
+ self.modulated_conv = ModulatedConv2d(
154
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
155
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
156
+
157
+ def forward(self, x, style, skip=None):
158
+ """Forward function.
159
+
160
+ Args:
161
+ x (Tensor): Feature tensor with shape (b, c, h, w).
162
+ style (Tensor): Tensor with shape (b, num_style_feat).
163
+ skip (Tensor): Base/skip tensor. Default: None.
164
+
165
+ Returns:
166
+ Tensor: RGB images.
167
+ """
168
+ out = self.modulated_conv(x, style)
169
+ out = out + self.bias
170
+ if skip is not None:
171
+ if self.upsample:
172
+ skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
173
+ out = out + skip
174
+ return out
175
+
176
+
177
+ class ConstantInput(nn.Module):
178
+ """Constant input.
179
+
180
+ Args:
181
+ num_channel (int): Channel number of constant input.
182
+ size (int): Spatial size of constant input.
183
+ """
184
+
185
+ def __init__(self, num_channel, size):
186
+ super(ConstantInput, self).__init__()
187
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
188
+
189
+ def forward(self, batch):
190
+ out = self.weight.repeat(batch, 1, 1, 1)
191
+ return out
192
+
193
+
194
+ @ARCH_REGISTRY.register()
195
+ class StyleGAN2GeneratorClean(nn.Module):
196
+ """Clean version of StyleGAN2 Generator.
197
+
198
+ Args:
199
+ out_size (int): The spatial size of outputs.
200
+ num_style_feat (int): Channel number of style features. Default: 512.
201
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
202
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
203
+ narrow (float): Narrow ratio for channels. Default: 1.0.
204
+ """
205
+
206
+ def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
207
+ super(StyleGAN2GeneratorClean, self).__init__()
208
+ # Style MLP layers
209
+ self.num_style_feat = num_style_feat
210
+ style_mlp_layers = [NormStyleCode()]
211
+ for i in range(num_mlp):
212
+ style_mlp_layers.extend(
213
+ [nn.Linear(num_style_feat, num_style_feat, bias=True),
214
+ nn.LeakyReLU(negative_slope=0.2, inplace=True)])
215
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
216
+ # initialization
217
+ default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
218
+
219
+ # channel list
220
+ channels = {
221
+ '4': int(512 * narrow),
222
+ '8': int(512 * narrow),
223
+ '16': int(512 * narrow),
224
+ '32': int(512 * narrow),
225
+ '64': int(256 * channel_multiplier * narrow),
226
+ '128': int(128 * channel_multiplier * narrow),
227
+ '256': int(64 * channel_multiplier * narrow),
228
+ '512': int(32 * channel_multiplier * narrow),
229
+ '1024': int(16 * channel_multiplier * narrow)
230
+ }
231
+ self.channels = channels
232
+
233
+ self.constant_input = ConstantInput(channels['4'], size=4)
234
+ self.style_conv1 = StyleConv(
235
+ channels['4'],
236
+ channels['4'],
237
+ kernel_size=3,
238
+ num_style_feat=num_style_feat,
239
+ demodulate=True,
240
+ sample_mode=None)
241
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
242
+
243
+ self.log_size = int(math.log(out_size, 2))
244
+ self.num_layers = (self.log_size - 2) * 2 + 1
245
+ self.num_latent = self.log_size * 2 - 2
246
+
247
+ self.style_convs = nn.ModuleList()
248
+ self.to_rgbs = nn.ModuleList()
249
+ self.noises = nn.Module()
250
+
251
+ in_channels = channels['4']
252
+ # noise
253
+ for layer_idx in range(self.num_layers):
254
+ resolution = 2**((layer_idx + 5) // 2)
255
+ shape = [1, 1, resolution, resolution]
256
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
257
+ # style convs and to_rgbs
258
+ for i in range(3, self.log_size + 1):
259
+ out_channels = channels[f'{2**i}']
260
+ self.style_convs.append(
261
+ StyleConv(
262
+ in_channels,
263
+ out_channels,
264
+ kernel_size=3,
265
+ num_style_feat=num_style_feat,
266
+ demodulate=True,
267
+ sample_mode='upsample'))
268
+ self.style_convs.append(
269
+ StyleConv(
270
+ out_channels,
271
+ out_channels,
272
+ kernel_size=3,
273
+ num_style_feat=num_style_feat,
274
+ demodulate=True,
275
+ sample_mode=None))
276
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
277
+ in_channels = out_channels
278
+
279
+ def make_noise(self):
280
+ """Make noise for noise injection."""
281
+ device = self.constant_input.weight.device
282
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
283
+
284
+ for i in range(3, self.log_size + 1):
285
+ for _ in range(2):
286
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
287
+
288
+ return noises
289
+
290
+ def get_latent(self, x):
291
+ return self.style_mlp(x)
292
+
293
+ def mean_latent(self, num_latent):
294
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
295
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
296
+ return latent
297
+
298
+ def forward(self,
299
+ styles,
300
+ input_is_latent=False,
301
+ noise=None,
302
+ randomize_noise=True,
303
+ truncation=1,
304
+ truncation_latent=None,
305
+ inject_index=None,
306
+ return_latents=False):
307
+ """Forward function for StyleGAN2GeneratorClean.
308
+
309
+ Args:
310
+ styles (list[Tensor]): Sample codes of styles.
311
+ input_is_latent (bool): Whether input is latent style. Default: False.
312
+ noise (Tensor | None): Input noise or None. Default: None.
313
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
314
+ truncation (float): The truncation ratio. Default: 1.
315
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
316
+ inject_index (int | None): The injection index for mixing noise. Default: None.
317
+ return_latents (bool): Whether to return style latents. Default: False.
318
+ """
319
+ # style codes -> latents with Style MLP layer
320
+ if not input_is_latent:
321
+ styles = [self.style_mlp(s) for s in styles]
322
+ # noises
323
+ if noise is None:
324
+ if randomize_noise:
325
+ noise = [None] * self.num_layers # for each style conv layer
326
+ else: # use the stored noise
327
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
328
+ # style truncation
329
+ if truncation < 1:
330
+ style_truncation = []
331
+ for style in styles:
332
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
333
+ styles = style_truncation
334
+ # get style latents with injection
335
+ if len(styles) == 1:
336
+ inject_index = self.num_latent
337
+
338
+ if styles[0].ndim < 3:
339
+ # repeat latent code for all the layers
340
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
341
+ else: # used for encoder with different latent code for each layer
342
+ latent = styles[0]
343
+ elif len(styles) == 2: # mixing noises
344
+ if inject_index is None:
345
+ inject_index = random.randint(1, self.num_latent - 1)
346
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
347
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
348
+ latent = torch.cat([latent1, latent2], 1)
349
+
350
+ # main generation
351
+ out = self.constant_input(latent.shape[0])
352
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
353
+ skip = self.to_rgb1(out, latent[:, 1])
354
+
355
+ i = 1
356
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
357
+ noise[2::2], self.to_rgbs):
358
+ out = conv1(out, latent[:, i], noise=noise1)
359
+ out = conv2(out, latent[:, i + 1], noise=noise2)
360
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
361
+ i += 2
362
+
363
+ image = skip
364
+
365
+ if return_latents:
366
+ return image, latent
367
+ else:
368
+ return image, None
gfpgan/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import dataset modules for registry
6
+ # scan all the files that end with '_dataset.py' under the data folder
7
+ data_folder = osp.dirname(osp.abspath(__file__))
8
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
9
+ # import all the dataset modules
10
+ _dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
gfpgan/data/ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os.path as osp
5
+ import torch
6
+ import torch.utils.data as data
7
+ from basicsr.data import degradations as degradations
8
+ from basicsr.data.data_util import paths_from_folder
9
+ from basicsr.data.transforms import augment
10
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
11
+ from basicsr.utils.registry import DATASET_REGISTRY
12
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
13
+ normalize)
14
+
15
+
16
+ @DATASET_REGISTRY.register()
17
+ class FFHQDegradationDataset(data.Dataset):
18
+ """FFHQ dataset for GFPGAN.
19
+
20
+ It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
21
+
22
+ Args:
23
+ opt (dict): Config for train datasets. It contains the following keys:
24
+ dataroot_gt (str): Data root path for gt.
25
+ io_backend (dict): IO backend type and other kwarg.
26
+ mean (list | tuple): Image mean.
27
+ std (list | tuple): Image std.
28
+ use_hflip (bool): Whether to horizontally flip.
29
+ Please see more options in the codes.
30
+ """
31
+
32
+ def __init__(self, opt):
33
+ super(FFHQDegradationDataset, self).__init__()
34
+ self.opt = opt
35
+ # file client (io backend)
36
+ self.file_client = None
37
+ self.io_backend_opt = opt['io_backend']
38
+
39
+ self.gt_folder = opt['dataroot_gt']
40
+ self.mean = opt['mean']
41
+ self.std = opt['std']
42
+ self.out_size = opt['out_size']
43
+
44
+ self.crop_components = opt.get('crop_components', False) # facial components
45
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
46
+
47
+ if self.crop_components:
48
+ # load component list from a pre-process pth files
49
+ self.components_list = torch.load(opt.get('component_path'))
50
+
51
+ # file client (lmdb io backend)
52
+ if self.io_backend_opt['type'] == 'lmdb':
53
+ self.io_backend_opt['db_paths'] = self.gt_folder
54
+ if not self.gt_folder.endswith('.lmdb'):
55
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
56
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
57
+ self.paths = [line.split('.')[0] for line in fin]
58
+ else:
59
+ # disk backend: scan file list from a folder
60
+ self.paths = paths_from_folder(self.gt_folder)
61
+
62
+ # degradation configurations
63
+ self.blur_kernel_size = opt['blur_kernel_size']
64
+ self.kernel_list = opt['kernel_list']
65
+ self.kernel_prob = opt['kernel_prob']
66
+ self.blur_sigma = opt['blur_sigma']
67
+ self.downsample_range = opt['downsample_range']
68
+ self.noise_range = opt['noise_range']
69
+ self.jpeg_range = opt['jpeg_range']
70
+
71
+ # color jitter
72
+ self.color_jitter_prob = opt.get('color_jitter_prob')
73
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
74
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
75
+ # to gray
76
+ self.gray_prob = opt.get('gray_prob')
77
+
78
+ logger = get_root_logger()
79
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
80
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
81
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
82
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
83
+
84
+ if self.color_jitter_prob is not None:
85
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
86
+ if self.gray_prob is not None:
87
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
88
+ self.color_jitter_shift /= 255.
89
+
90
+ @staticmethod
91
+ def color_jitter(img, shift):
92
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
93
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
94
+ img = img + jitter_val
95
+ img = np.clip(img, 0, 1)
96
+ return img
97
+
98
+ @staticmethod
99
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
100
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
101
+ fn_idx = torch.randperm(4)
102
+ for fn_id in fn_idx:
103
+ if fn_id == 0 and brightness is not None:
104
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
105
+ img = adjust_brightness(img, brightness_factor)
106
+
107
+ if fn_id == 1 and contrast is not None:
108
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
109
+ img = adjust_contrast(img, contrast_factor)
110
+
111
+ if fn_id == 2 and saturation is not None:
112
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
113
+ img = adjust_saturation(img, saturation_factor)
114
+
115
+ if fn_id == 3 and hue is not None:
116
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
117
+ img = adjust_hue(img, hue_factor)
118
+ return img
119
+
120
+ def get_component_coordinates(self, index, status):
121
+ """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
122
+ components_bbox = self.components_list[f'{index:08d}']
123
+ if status[0]: # hflip
124
+ # exchange right and left eye
125
+ tmp = components_bbox['left_eye']
126
+ components_bbox['left_eye'] = components_bbox['right_eye']
127
+ components_bbox['right_eye'] = tmp
128
+ # modify the width coordinate
129
+ components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
130
+ components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
131
+ components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
132
+
133
+ # get coordinates
134
+ locations = []
135
+ for part in ['left_eye', 'right_eye', 'mouth']:
136
+ mean = components_bbox[part][0:2]
137
+ half_len = components_bbox[part][2]
138
+ if 'eye' in part:
139
+ half_len *= self.eye_enlarge_ratio
140
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
141
+ loc = torch.from_numpy(loc).float()
142
+ locations.append(loc)
143
+ return locations
144
+
145
+ def __getitem__(self, index):
146
+ if self.file_client is None:
147
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
148
+
149
+ # load gt image
150
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
151
+ gt_path = self.paths[index]
152
+ img_bytes = self.file_client.get(gt_path)
153
+ img_gt = imfrombytes(img_bytes, float32=True)
154
+
155
+ # random horizontal flip
156
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
157
+ h, w, _ = img_gt.shape
158
+
159
+ # get facial component coordinates
160
+ if self.crop_components:
161
+ locations = self.get_component_coordinates(index, status)
162
+ loc_left_eye, loc_right_eye, loc_mouth = locations
163
+
164
+ # ------------------------ generate lq image ------------------------ #
165
+ # blur
166
+ kernel = degradations.random_mixed_kernels(
167
+ self.kernel_list,
168
+ self.kernel_prob,
169
+ self.blur_kernel_size,
170
+ self.blur_sigma,
171
+ self.blur_sigma, [-math.pi, math.pi],
172
+ noise_range=None)
173
+ img_lq = cv2.filter2D(img_gt, -1, kernel)
174
+ # downsample
175
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
176
+ img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
177
+ # noise
178
+ if self.noise_range is not None:
179
+ img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
180
+ # jpeg compression
181
+ if self.jpeg_range is not None:
182
+ img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
183
+
184
+ # resize to original size
185
+ img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
186
+
187
+ # random color jitter (only for lq)
188
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
189
+ img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
190
+ # random to gray (only for lq)
191
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
192
+ img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
193
+ img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
194
+ if self.opt.get('gt_gray'): # whether convert GT to gray images
195
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
196
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
197
+
198
+ # BGR to RGB, HWC to CHW, numpy to tensor
199
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
200
+
201
+ # random color jitter (pytorch version) (only for lq)
202
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
203
+ brightness = self.opt.get('brightness', (0.5, 1.5))
204
+ contrast = self.opt.get('contrast', (0.5, 1.5))
205
+ saturation = self.opt.get('saturation', (0, 1.5))
206
+ hue = self.opt.get('hue', (-0.1, 0.1))
207
+ img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
208
+
209
+ # round and clip
210
+ img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
211
+
212
+ # normalize
213
+ normalize(img_gt, self.mean, self.std, inplace=True)
214
+ normalize(img_lq, self.mean, self.std, inplace=True)
215
+
216
+ if self.crop_components:
217
+ return_dict = {
218
+ 'lq': img_lq,
219
+ 'gt': img_gt,
220
+ 'gt_path': gt_path,
221
+ 'loc_left_eye': loc_left_eye,
222
+ 'loc_right_eye': loc_right_eye,
223
+ 'loc_mouth': loc_mouth
224
+ }
225
+ return return_dict
226
+ else:
227
+ return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
228
+
229
+ def __len__(self):
230
+ return len(self.paths)
gfpgan/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import model modules for registry
6
+ # scan all the files that end with '_model.py' under the model folder
7
+ model_folder = osp.dirname(osp.abspath(__file__))
8
+ model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
9
+ # import all the model modules
10
+ _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
gfpgan/models/gfpgan_model.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os.path as osp
3
+ import torch
4
+ from basicsr.archs import build_network
5
+ from basicsr.losses import build_loss
6
+ from basicsr.losses.gan_loss import r1_penalty
7
+ from basicsr.metrics import calculate_metric
8
+ from basicsr.models.base_model import BaseModel
9
+ from basicsr.utils import get_root_logger, imwrite, tensor2img
10
+ from basicsr.utils.registry import MODEL_REGISTRY
11
+ from collections import OrderedDict
12
+ from torch.nn import functional as F
13
+ from torchvision.ops import roi_align
14
+ from tqdm import tqdm
15
+
16
+
17
+ @MODEL_REGISTRY.register()
18
+ class GFPGANModel(BaseModel):
19
+ """The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
20
+
21
+ def __init__(self, opt):
22
+ super(GFPGANModel, self).__init__(opt)
23
+ self.idx = 0 # it is used for saving data for check
24
+
25
+ # define network
26
+ self.net_g = build_network(opt['network_g'])
27
+ self.net_g = self.model_to_device(self.net_g)
28
+ self.print_network(self.net_g)
29
+
30
+ # load pretrained model
31
+ load_path = self.opt['path'].get('pretrain_network_g', None)
32
+ if load_path is not None:
33
+ param_key = self.opt['path'].get('param_key_g', 'params')
34
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
35
+
36
+ self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
37
+
38
+ if self.is_train:
39
+ self.init_training_settings()
40
+
41
+ def init_training_settings(self):
42
+ train_opt = self.opt['train']
43
+
44
+ # ----------- define net_d ----------- #
45
+ self.net_d = build_network(self.opt['network_d'])
46
+ self.net_d = self.model_to_device(self.net_d)
47
+ self.print_network(self.net_d)
48
+ # load pretrained model
49
+ load_path = self.opt['path'].get('pretrain_network_d', None)
50
+ if load_path is not None:
51
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
52
+
53
+ # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
54
+ # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
55
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
56
+ # load pretrained model
57
+ load_path = self.opt['path'].get('pretrain_network_g', None)
58
+ if load_path is not None:
59
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
60
+ else:
61
+ self.model_ema(0) # copy net_g weight
62
+
63
+ self.net_g.train()
64
+ self.net_d.train()
65
+ self.net_g_ema.eval()
66
+
67
+ # ----------- facial component networks ----------- #
68
+ if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
69
+ self.use_facial_disc = True
70
+ else:
71
+ self.use_facial_disc = False
72
+
73
+ if self.use_facial_disc:
74
+ # left eye
75
+ self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
76
+ self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
77
+ self.print_network(self.net_d_left_eye)
78
+ load_path = self.opt['path'].get('pretrain_network_d_left_eye')
79
+ if load_path is not None:
80
+ self.load_network(self.net_d_left_eye, load_path, True, 'params')
81
+ # right eye
82
+ self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
83
+ self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
84
+ self.print_network(self.net_d_right_eye)
85
+ load_path = self.opt['path'].get('pretrain_network_d_right_eye')
86
+ if load_path is not None:
87
+ self.load_network(self.net_d_right_eye, load_path, True, 'params')
88
+ # mouth
89
+ self.net_d_mouth = build_network(self.opt['network_d_mouth'])
90
+ self.net_d_mouth = self.model_to_device(self.net_d_mouth)
91
+ self.print_network(self.net_d_mouth)
92
+ load_path = self.opt['path'].get('pretrain_network_d_mouth')
93
+ if load_path is not None:
94
+ self.load_network(self.net_d_mouth, load_path, True, 'params')
95
+
96
+ self.net_d_left_eye.train()
97
+ self.net_d_right_eye.train()
98
+ self.net_d_mouth.train()
99
+
100
+ # ----------- define facial component gan loss ----------- #
101
+ self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
102
+
103
+ # ----------- define losses ----------- #
104
+ # pixel loss
105
+ if train_opt.get('pixel_opt'):
106
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
107
+ else:
108
+ self.cri_pix = None
109
+
110
+ # perceptual loss
111
+ if train_opt.get('perceptual_opt'):
112
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
113
+ else:
114
+ self.cri_perceptual = None
115
+
116
+ # L1 loss is used in pyramid loss, component style loss and identity loss
117
+ self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
118
+
119
+ # gan loss (wgan)
120
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
121
+
122
+ # ----------- define identity loss ----------- #
123
+ if 'network_identity' in self.opt:
124
+ self.use_identity = True
125
+ else:
126
+ self.use_identity = False
127
+
128
+ if self.use_identity:
129
+ # define identity network
130
+ self.network_identity = build_network(self.opt['network_identity'])
131
+ self.network_identity = self.model_to_device(self.network_identity)
132
+ self.print_network(self.network_identity)
133
+ load_path = self.opt['path'].get('pretrain_network_identity')
134
+ if load_path is not None:
135
+ self.load_network(self.network_identity, load_path, True, None)
136
+ self.network_identity.eval()
137
+ for param in self.network_identity.parameters():
138
+ param.requires_grad = False
139
+
140
+ # regularization weights
141
+ self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
142
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
143
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
144
+ self.net_d_reg_every = train_opt['net_d_reg_every']
145
+
146
+ # set up optimizers and schedulers
147
+ self.setup_optimizers()
148
+ self.setup_schedulers()
149
+
150
+ def setup_optimizers(self):
151
+ train_opt = self.opt['train']
152
+
153
+ # ----------- optimizer g ----------- #
154
+ net_g_reg_ratio = 1
155
+ normal_params = []
156
+ for _, param in self.net_g.named_parameters():
157
+ normal_params.append(param)
158
+ optim_params_g = [{ # add normal params first
159
+ 'params': normal_params,
160
+ 'lr': train_opt['optim_g']['lr']
161
+ }]
162
+ optim_type = train_opt['optim_g'].pop('type')
163
+ lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
164
+ betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
165
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
166
+ self.optimizers.append(self.optimizer_g)
167
+
168
+ # ----------- optimizer d ----------- #
169
+ net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
170
+ normal_params = []
171
+ for _, param in self.net_d.named_parameters():
172
+ normal_params.append(param)
173
+ optim_params_d = [{ # add normal params first
174
+ 'params': normal_params,
175
+ 'lr': train_opt['optim_d']['lr']
176
+ }]
177
+ optim_type = train_opt['optim_d'].pop('type')
178
+ lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
179
+ betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
180
+ self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
181
+ self.optimizers.append(self.optimizer_d)
182
+
183
+ # ----------- optimizers for facial component networks ----------- #
184
+ if self.use_facial_disc:
185
+ # setup optimizers for facial component discriminators
186
+ optim_type = train_opt['optim_component'].pop('type')
187
+ lr = train_opt['optim_component']['lr']
188
+ # left eye
189
+ self.optimizer_d_left_eye = self.get_optimizer(
190
+ optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
191
+ self.optimizers.append(self.optimizer_d_left_eye)
192
+ # right eye
193
+ self.optimizer_d_right_eye = self.get_optimizer(
194
+ optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
195
+ self.optimizers.append(self.optimizer_d_right_eye)
196
+ # mouth
197
+ self.optimizer_d_mouth = self.get_optimizer(
198
+ optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
199
+ self.optimizers.append(self.optimizer_d_mouth)
200
+
201
+ def feed_data(self, data):
202
+ self.lq = data['lq'].to(self.device)
203
+ if 'gt' in data:
204
+ self.gt = data['gt'].to(self.device)
205
+
206
+ if 'loc_left_eye' in data:
207
+ # get facial component locations, shape (batch, 4)
208
+ self.loc_left_eyes = data['loc_left_eye']
209
+ self.loc_right_eyes = data['loc_right_eye']
210
+ self.loc_mouths = data['loc_mouth']
211
+
212
+ # uncomment to check data
213
+ # import torchvision
214
+ # if self.opt['rank'] == 0:
215
+ # import os
216
+ # os.makedirs('tmp/gt', exist_ok=True)
217
+ # os.makedirs('tmp/lq', exist_ok=True)
218
+ # print(self.idx)
219
+ # torchvision.utils.save_image(
220
+ # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
221
+ # torchvision.utils.save_image(
222
+ # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
223
+ # self.idx = self.idx + 1
224
+
225
+ def construct_img_pyramid(self):
226
+ """Construct image pyramid for intermediate restoration loss"""
227
+ pyramid_gt = [self.gt]
228
+ down_img = self.gt
229
+ for _ in range(0, self.log_size - 3):
230
+ down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
231
+ pyramid_gt.insert(0, down_img)
232
+ return pyramid_gt
233
+
234
+ def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
235
+ face_ratio = int(self.opt['network_g']['out_size'] / 512)
236
+ eye_out_size *= face_ratio
237
+ mouth_out_size *= face_ratio
238
+
239
+ rois_eyes = []
240
+ rois_mouths = []
241
+ for b in range(self.loc_left_eyes.size(0)): # loop for batch size
242
+ # left eye and right eye
243
+ img_inds = self.loc_left_eyes.new_full((2, 1), b)
244
+ bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
245
+ rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
246
+ rois_eyes.append(rois)
247
+ # mouse
248
+ img_inds = self.loc_left_eyes.new_full((1, 1), b)
249
+ rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
250
+ rois_mouths.append(rois)
251
+
252
+ rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
253
+ rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
254
+
255
+ # real images
256
+ all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
257
+ self.left_eyes_gt = all_eyes[0::2, :, :, :]
258
+ self.right_eyes_gt = all_eyes[1::2, :, :, :]
259
+ self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
260
+ # output
261
+ all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
262
+ self.left_eyes = all_eyes[0::2, :, :, :]
263
+ self.right_eyes = all_eyes[1::2, :, :, :]
264
+ self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
265
+
266
+ def _gram_mat(self, x):
267
+ """Calculate Gram matrix.
268
+
269
+ Args:
270
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
271
+
272
+ Returns:
273
+ torch.Tensor: Gram matrix.
274
+ """
275
+ n, c, h, w = x.size()
276
+ features = x.view(n, c, w * h)
277
+ features_t = features.transpose(1, 2)
278
+ gram = features.bmm(features_t) / (c * h * w)
279
+ return gram
280
+
281
+ def gray_resize_for_identity(self, out, size=128):
282
+ out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
283
+ out_gray = out_gray.unsqueeze(1)
284
+ out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
285
+ return out_gray
286
+
287
+ def optimize_parameters(self, current_iter):
288
+ # optimize net_g
289
+ for p in self.net_d.parameters():
290
+ p.requires_grad = False
291
+ self.optimizer_g.zero_grad()
292
+
293
+ # do not update facial component net_d
294
+ if self.use_facial_disc:
295
+ for p in self.net_d_left_eye.parameters():
296
+ p.requires_grad = False
297
+ for p in self.net_d_right_eye.parameters():
298
+ p.requires_grad = False
299
+ for p in self.net_d_mouth.parameters():
300
+ p.requires_grad = False
301
+
302
+ # image pyramid loss weight
303
+ pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
304
+ if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
305
+ pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
306
+ if pyramid_loss_weight > 0:
307
+ self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
308
+ pyramid_gt = self.construct_img_pyramid()
309
+ else:
310
+ self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
311
+
312
+ # get roi-align regions
313
+ if self.use_facial_disc:
314
+ self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
315
+
316
+ l_g_total = 0
317
+ loss_dict = OrderedDict()
318
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
319
+ # pixel loss
320
+ if self.cri_pix:
321
+ l_g_pix = self.cri_pix(self.output, self.gt)
322
+ l_g_total += l_g_pix
323
+ loss_dict['l_g_pix'] = l_g_pix
324
+
325
+ # image pyramid loss
326
+ if pyramid_loss_weight > 0:
327
+ for i in range(0, self.log_size - 2):
328
+ l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
329
+ l_g_total += l_pyramid
330
+ loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
331
+
332
+ # perceptual loss
333
+ if self.cri_perceptual:
334
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
335
+ if l_g_percep is not None:
336
+ l_g_total += l_g_percep
337
+ loss_dict['l_g_percep'] = l_g_percep
338
+ if l_g_style is not None:
339
+ l_g_total += l_g_style
340
+ loss_dict['l_g_style'] = l_g_style
341
+
342
+ # gan loss
343
+ fake_g_pred = self.net_d(self.output)
344
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
345
+ l_g_total += l_g_gan
346
+ loss_dict['l_g_gan'] = l_g_gan
347
+
348
+ # facial component loss
349
+ if self.use_facial_disc:
350
+ # left eye
351
+ fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
352
+ l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
353
+ l_g_total += l_g_gan
354
+ loss_dict['l_g_gan_left_eye'] = l_g_gan
355
+ # right eye
356
+ fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
357
+ l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
358
+ l_g_total += l_g_gan
359
+ loss_dict['l_g_gan_right_eye'] = l_g_gan
360
+ # mouth
361
+ fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
362
+ l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
363
+ l_g_total += l_g_gan
364
+ loss_dict['l_g_gan_mouth'] = l_g_gan
365
+
366
+ if self.opt['train'].get('comp_style_weight', 0) > 0:
367
+ # get gt feat
368
+ _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
369
+ _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
370
+ _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
371
+
372
+ def _comp_style(feat, feat_gt, criterion):
373
+ return criterion(self._gram_mat(feat[0]), self._gram_mat(
374
+ feat_gt[0].detach())) * 0.5 + criterion(
375
+ self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
376
+
377
+ # facial component style loss
378
+ comp_style_loss = 0
379
+ comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
380
+ comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
381
+ comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
382
+ comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
383
+ l_g_total += comp_style_loss
384
+ loss_dict['l_g_comp_style_loss'] = comp_style_loss
385
+
386
+ # identity loss
387
+ if self.use_identity:
388
+ identity_weight = self.opt['train']['identity_weight']
389
+ # get gray images and resize
390
+ out_gray = self.gray_resize_for_identity(self.output)
391
+ gt_gray = self.gray_resize_for_identity(self.gt)
392
+
393
+ identity_gt = self.network_identity(gt_gray).detach()
394
+ identity_out = self.network_identity(out_gray)
395
+ l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
396
+ l_g_total += l_identity
397
+ loss_dict['l_identity'] = l_identity
398
+
399
+ l_g_total.backward()
400
+ self.optimizer_g.step()
401
+
402
+ # EMA
403
+ self.model_ema(decay=0.5**(32 / (10 * 1000)))
404
+
405
+ # ----------- optimize net_d ----------- #
406
+ for p in self.net_d.parameters():
407
+ p.requires_grad = True
408
+ self.optimizer_d.zero_grad()
409
+ if self.use_facial_disc:
410
+ for p in self.net_d_left_eye.parameters():
411
+ p.requires_grad = True
412
+ for p in self.net_d_right_eye.parameters():
413
+ p.requires_grad = True
414
+ for p in self.net_d_mouth.parameters():
415
+ p.requires_grad = True
416
+ self.optimizer_d_left_eye.zero_grad()
417
+ self.optimizer_d_right_eye.zero_grad()
418
+ self.optimizer_d_mouth.zero_grad()
419
+
420
+ fake_d_pred = self.net_d(self.output.detach())
421
+ real_d_pred = self.net_d(self.gt)
422
+ l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
423
+ loss_dict['l_d'] = l_d
424
+ # In WGAN, real_score should be positive and fake_score should be negative
425
+ loss_dict['real_score'] = real_d_pred.detach().mean()
426
+ loss_dict['fake_score'] = fake_d_pred.detach().mean()
427
+ l_d.backward()
428
+
429
+ # regularization loss
430
+ if current_iter % self.net_d_reg_every == 0:
431
+ self.gt.requires_grad = True
432
+ real_pred = self.net_d(self.gt)
433
+ l_d_r1 = r1_penalty(real_pred, self.gt)
434
+ l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
435
+ loss_dict['l_d_r1'] = l_d_r1.detach().mean()
436
+ l_d_r1.backward()
437
+
438
+ self.optimizer_d.step()
439
+
440
+ # optimize facial component discriminators
441
+ if self.use_facial_disc:
442
+ # left eye
443
+ fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
444
+ real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
445
+ l_d_left_eye = self.cri_component(
446
+ real_d_pred, True, is_disc=True) + self.cri_gan(
447
+ fake_d_pred, False, is_disc=True)
448
+ loss_dict['l_d_left_eye'] = l_d_left_eye
449
+ l_d_left_eye.backward()
450
+ # right eye
451
+ fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
452
+ real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
453
+ l_d_right_eye = self.cri_component(
454
+ real_d_pred, True, is_disc=True) + self.cri_gan(
455
+ fake_d_pred, False, is_disc=True)
456
+ loss_dict['l_d_right_eye'] = l_d_right_eye
457
+ l_d_right_eye.backward()
458
+ # mouth
459
+ fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
460
+ real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
461
+ l_d_mouth = self.cri_component(
462
+ real_d_pred, True, is_disc=True) + self.cri_gan(
463
+ fake_d_pred, False, is_disc=True)
464
+ loss_dict['l_d_mouth'] = l_d_mouth
465
+ l_d_mouth.backward()
466
+
467
+ self.optimizer_d_left_eye.step()
468
+ self.optimizer_d_right_eye.step()
469
+ self.optimizer_d_mouth.step()
470
+
471
+ self.log_dict = self.reduce_loss_dict(loss_dict)
472
+
473
+ def test(self):
474
+ with torch.no_grad():
475
+ if hasattr(self, 'net_g_ema'):
476
+ self.net_g_ema.eval()
477
+ self.output, _ = self.net_g_ema(self.lq)
478
+ else:
479
+ logger = get_root_logger()
480
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
481
+ self.net_g.eval()
482
+ self.output, _ = self.net_g(self.lq)
483
+ self.net_g.train()
484
+
485
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
486
+ if self.opt['rank'] == 0:
487
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
488
+
489
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
490
+ dataset_name = dataloader.dataset.opt['name']
491
+ with_metrics = self.opt['val'].get('metrics') is not None
492
+ use_pbar = self.opt['val'].get('pbar', False)
493
+
494
+ if with_metrics:
495
+ if not hasattr(self, 'metric_results'): # only execute in the first run
496
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
497
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
498
+ self._initialize_best_metric_results(dataset_name)
499
+ # zero self.metric_results
500
+ self.metric_results = {metric: 0 for metric in self.metric_results}
501
+
502
+ metric_data = dict()
503
+ if use_pbar:
504
+ pbar = tqdm(total=len(dataloader), unit='image')
505
+
506
+ for idx, val_data in enumerate(dataloader):
507
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
508
+ self.feed_data(val_data)
509
+ self.test()
510
+
511
+ sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
512
+ metric_data['img'] = sr_img
513
+ if hasattr(self, 'gt'):
514
+ gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
515
+ metric_data['img2'] = gt_img
516
+ del self.gt
517
+
518
+ # tentative for out of GPU memory
519
+ del self.lq
520
+ del self.output
521
+ torch.cuda.empty_cache()
522
+
523
+ if save_img:
524
+ if self.opt['is_train']:
525
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
526
+ f'{img_name}_{current_iter}.png')
527
+ else:
528
+ if self.opt['val']['suffix']:
529
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
530
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
531
+ else:
532
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
533
+ f'{img_name}_{self.opt["name"]}.png')
534
+ imwrite(sr_img, save_img_path)
535
+
536
+ if with_metrics:
537
+ # calculate metrics
538
+ for name, opt_ in self.opt['val']['metrics'].items():
539
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
540
+ if use_pbar:
541
+ pbar.update(1)
542
+ pbar.set_description(f'Test {img_name}')
543
+ if use_pbar:
544
+ pbar.close()
545
+
546
+ if with_metrics:
547
+ for metric in self.metric_results.keys():
548
+ self.metric_results[metric] /= (idx + 1)
549
+ # update the best metric result
550
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
551
+
552
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
553
+
554
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
555
+ log_str = f'Validation {dataset_name}\n'
556
+ for metric, value in self.metric_results.items():
557
+ log_str += f'\t # {metric}: {value:.4f}'
558
+ if hasattr(self, 'best_metric_results'):
559
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
560
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
561
+ log_str += '\n'
562
+
563
+ logger = get_root_logger()
564
+ logger.info(log_str)
565
+ if tb_logger:
566
+ for metric, value in self.metric_results.items():
567
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
568
+
569
+ def save(self, epoch, current_iter):
570
+ # save net_g and net_d
571
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
572
+ self.save_network(self.net_d, 'net_d', current_iter)
573
+ # save component discriminators
574
+ if self.use_facial_disc:
575
+ self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
576
+ self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
577
+ self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
578
+ # save training state
579
+ self.save_training_state(epoch, current_iter)
gfpgan/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import os.path as osp
3
+ from basicsr.train import train_pipeline
4
+
5
+ import gfpgan.archs
6
+ import gfpgan.data
7
+ import gfpgan.models
8
+
9
+ if __name__ == '__main__':
10
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11
+ train_pipeline(root_path)
gfpgan/utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import torch
4
+ from basicsr.utils import img2tensor, tensor2img
5
+ from basicsr.utils.download_util import load_file_from_url
6
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
7
+ from torchvision.transforms.functional import normalize
8
+
9
+ from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
10
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
11
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
12
+
13
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+
15
+
16
+ class GFPGANer():
17
+ """Helper for restoration with GFPGAN.
18
+
19
+ It will detect and crop faces, and then resize the faces to 512x512.
20
+ GFPGAN is used to restored the resized faces.
21
+ The background is upsampled with the bg_upsampler.
22
+ Finally, the faces will be pasted back to the upsample background image.
23
+
24
+ Args:
25
+ model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
26
+ upscale (float): The upscale of the final output. Default: 2.
27
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
28
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
29
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
30
+ """
31
+
32
+ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
33
+ self.upscale = upscale
34
+ self.bg_upsampler = bg_upsampler
35
+
36
+ # initialize model
37
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
38
+ # initialize the GFP-GAN
39
+ if arch == 'clean':
40
+ self.gfpgan = GFPGANv1Clean(
41
+ out_size=512,
42
+ num_style_feat=512,
43
+ channel_multiplier=channel_multiplier,
44
+ decoder_load_path=None,
45
+ fix_decoder=False,
46
+ num_mlp=8,
47
+ input_is_latent=True,
48
+ different_w=True,
49
+ narrow=1,
50
+ sft_half=True)
51
+ elif arch == 'bilinear':
52
+ self.gfpgan = GFPGANBilinear(
53
+ out_size=512,
54
+ num_style_feat=512,
55
+ channel_multiplier=channel_multiplier,
56
+ decoder_load_path=None,
57
+ fix_decoder=False,
58
+ num_mlp=8,
59
+ input_is_latent=True,
60
+ different_w=True,
61
+ narrow=1,
62
+ sft_half=True)
63
+ elif arch == 'original':
64
+ self.gfpgan = GFPGANv1(
65
+ out_size=512,
66
+ num_style_feat=512,
67
+ channel_multiplier=channel_multiplier,
68
+ decoder_load_path=None,
69
+ fix_decoder=True,
70
+ num_mlp=8,
71
+ input_is_latent=True,
72
+ different_w=True,
73
+ narrow=1,
74
+ sft_half=True)
75
+ # initialize face helper
76
+ self.face_helper = FaceRestoreHelper(
77
+ upscale,
78
+ face_size=512,
79
+ crop_ratio=(1, 1),
80
+ det_model='retinaface_resnet50',
81
+ save_ext='png',
82
+ use_parse=True,
83
+ device=self.device)
84
+
85
+ if model_path.startswith('https://'):
86
+ model_path = load_file_from_url(
87
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
88
+ loadnet = torch.load(model_path)
89
+ if 'params_ema' in loadnet:
90
+ keyname = 'params_ema'
91
+ else:
92
+ keyname = 'params'
93
+ self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
94
+ self.gfpgan.eval()
95
+ self.gfpgan = self.gfpgan.to(self.device)
96
+
97
+ @torch.no_grad()
98
+ def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
99
+ self.face_helper.clean_all()
100
+
101
+ if has_aligned: # the inputs are already aligned
102
+ img = cv2.resize(img, (512, 512))
103
+ self.face_helper.cropped_faces = [img]
104
+ else:
105
+ self.face_helper.read_image(img)
106
+ # get face landmarks for each face
107
+ self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
108
+ # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
109
+ # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
110
+ # align and warp each face
111
+ self.face_helper.align_warp_face()
112
+
113
+ # face restoration
114
+ for cropped_face in self.face_helper.cropped_faces:
115
+ # prepare data
116
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
117
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
118
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
119
+
120
+ try:
121
+ output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
122
+ # convert to image
123
+ restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
124
+ except RuntimeError as error:
125
+ print(f'\tFailed inference for GFPGAN: {error}.')
126
+ restored_face = cropped_face
127
+
128
+ restored_face = restored_face.astype('uint8')
129
+ self.face_helper.add_restored_face(restored_face)
130
+
131
+ if not has_aligned and paste_back:
132
+ # upsample the background
133
+ if self.bg_upsampler is not None:
134
+ # Now only support RealESRGAN for upsampling background
135
+ bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
136
+ else:
137
+ bg_img = None
138
+
139
+ self.face_helper.get_inverse_affine(None)
140
+ # paste each restored face to the input image
141
+ restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
142
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
143
+ else:
144
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
gfpgan/weights/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Weights
2
+
3
+ Put the downloaded weights to this folder.
inference_gfpgan.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import numpy as np
5
+ import os
6
+ import torch
7
+ from basicsr.utils import imwrite
8
+
9
+ from gfpgan import GFPGANer
10
+
11
+
12
+ def main():
13
+ """Inference demo for GFPGAN (for users).
14
+ """
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '-i',
18
+ '--input',
19
+ type=str,
20
+ default='inputs/whole_imgs',
21
+ help='Input image or folder. Default: inputs/whole_imgs')
22
+ parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results')
23
+ # we use version to select models, which is more user-friendly
24
+ parser.add_argument(
25
+ '-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3')
26
+ parser.add_argument(
27
+ '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
28
+
29
+ parser.add_argument(
30
+ '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan')
31
+ parser.add_argument(
32
+ '--bg_tile',
33
+ type=int,
34
+ default=400,
35
+ help='Tile size for background sampler, 0 for no tile during testing. Default: 400')
36
+ parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
37
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
38
+ parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
39
+ parser.add_argument(
40
+ '--ext',
41
+ type=str,
42
+ default='auto',
43
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto')
44
+ args = parser.parse_args()
45
+
46
+ args = parser.parse_args()
47
+
48
+ # ------------------------ input & output ------------------------
49
+ if args.input.endswith('/'):
50
+ args.input = args.input[:-1]
51
+ if os.path.isfile(args.input):
52
+ img_list = [args.input]
53
+ else:
54
+ img_list = sorted(glob.glob(os.path.join(args.input, '*')))
55
+
56
+ os.makedirs(args.output, exist_ok=True)
57
+
58
+ # ------------------------ set up background upsampler ------------------------
59
+ if args.bg_upsampler == 'realesrgan':
60
+ if not torch.cuda.is_available(): # CPU
61
+ import warnings
62
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
63
+ 'If you really want to use it, please modify the corresponding codes.')
64
+ bg_upsampler = None
65
+ else:
66
+ from basicsr.archs.rrdbnet_arch import RRDBNet
67
+ from realesrgan import RealESRGANer
68
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
69
+ bg_upsampler = RealESRGANer(
70
+ scale=2,
71
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
72
+ model=model,
73
+ tile=args.bg_tile,
74
+ tile_pad=10,
75
+ pre_pad=0,
76
+ half=True) # need to set False in CPU mode
77
+ else:
78
+ bg_upsampler = None
79
+
80
+ # ------------------------ set up GFPGAN restorer ------------------------
81
+ if args.version == '1':
82
+ arch = 'original'
83
+ channel_multiplier = 1
84
+ model_name = 'GFPGANv1'
85
+ elif args.version == '1.2':
86
+ arch = 'clean'
87
+ channel_multiplier = 2
88
+ model_name = 'GFPGANCleanv1-NoCE-C2'
89
+ elif args.version == '1.3':
90
+ arch = 'clean'
91
+ channel_multiplier = 2
92
+ model_name = 'GFPGANv1.3'
93
+ else:
94
+ raise ValueError(f'Wrong model version {args.version}.')
95
+
96
+ # determine model paths
97
+ model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
98
+ if not os.path.isfile(model_path):
99
+ model_path = os.path.join('realesrgan/weights', model_name + '.pth')
100
+ if not os.path.isfile(model_path):
101
+ raise ValueError(f'Model {model_name} does not exist.')
102
+
103
+ restorer = GFPGANer(
104
+ model_path=model_path,
105
+ upscale=args.upscale,
106
+ arch=arch,
107
+ channel_multiplier=channel_multiplier,
108
+ bg_upsampler=bg_upsampler)
109
+
110
+ # ------------------------ restore ------------------------
111
+ for img_path in img_list:
112
+ # read image
113
+ img_name = os.path.basename(img_path)
114
+ print(f'Processing {img_name} ...')
115
+ basename, ext = os.path.splitext(img_name)
116
+ input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
117
+
118
+ # restore faces and background if necessary
119
+ cropped_faces, restored_faces, restored_img = restorer.enhance(
120
+ input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=True)
121
+
122
+ # save faces
123
+ for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
124
+ # save cropped face
125
+ save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
126
+ imwrite(cropped_face, save_crop_path)
127
+ # save restored face
128
+ if args.suffix is not None:
129
+ save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
130
+ else:
131
+ save_face_name = f'{basename}_{idx:02d}.png'
132
+ save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
133
+ imwrite(restored_face, save_restore_path)
134
+ # save comparison image
135
+ cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
136
+ imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))
137
+
138
+ # save restored img
139
+ if restored_img is not None:
140
+ if args.ext == 'auto':
141
+ extension = ext[1:]
142
+ else:
143
+ extension = args.ext
144
+
145
+ if args.suffix is not None:
146
+ save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
147
+ else:
148
+ save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
149
+ imwrite(restored_img, save_restore_path)
150
+
151
+ print(f'Results are in the [{args.output}] folder.')
152
+
153
+
154
+ if __name__ == '__main__':
155
+ main()
options/train_gfpgan_v1.yml ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_GFPGANv1_512
3
+ model_type: GFPGANModel
4
+ num_gpu: auto # officially, we use 4 GPUs
5
+ manual_seed: 0
6
+
7
+ # dataset and data loader settings
8
+ datasets:
9
+ train:
10
+ name: FFHQ
11
+ type: FFHQDegradationDataset
12
+ # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
13
+ dataroot_gt: datasets/ffhq/ffhq_512
14
+ io_backend:
15
+ # type: lmdb
16
+ type: disk
17
+
18
+ use_hflip: true
19
+ mean: [0.5, 0.5, 0.5]
20
+ std: [0.5, 0.5, 0.5]
21
+ out_size: 512
22
+
23
+ blur_kernel_size: 41
24
+ kernel_list: ['iso', 'aniso']
25
+ kernel_prob: [0.5, 0.5]
26
+ blur_sigma: [0.1, 10]
27
+ downsample_range: [0.8, 8]
28
+ noise_range: [0, 20]
29
+ jpeg_range: [60, 100]
30
+
31
+ # color jitter and gray
32
+ color_jitter_prob: 0.3
33
+ color_jitter_shift: 20
34
+ color_jitter_pt_prob: 0.3
35
+ gray_prob: 0.01
36
+
37
+ # If you do not want colorization, please set
38
+ # color_jitter_prob: ~
39
+ # color_jitter_pt_prob: ~
40
+ # gray_prob: 0.01
41
+ # gt_gray: True
42
+
43
+ crop_components: true
44
+ component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
45
+ eye_enlarge_ratio: 1.4
46
+
47
+ # data loader
48
+ use_shuffle: true
49
+ num_worker_per_gpu: 6
50
+ batch_size_per_gpu: 3
51
+ dataset_enlarge_ratio: 1
52
+ prefetch_mode: ~
53
+
54
+ val:
55
+ # Please modify accordingly to use your own validation
56
+ # Or comment the val block if do not need validation during training
57
+ name: validation
58
+ type: PairedImageDataset
59
+ dataroot_lq: datasets/faces/validation/input
60
+ dataroot_gt: datasets/faces/validation/reference
61
+ io_backend:
62
+ type: disk
63
+ mean: [0.5, 0.5, 0.5]
64
+ std: [0.5, 0.5, 0.5]
65
+ scale: 1
66
+
67
+ # network structures
68
+ network_g:
69
+ type: GFPGANv1
70
+ out_size: 512
71
+ num_style_feat: 512
72
+ channel_multiplier: 1
73
+ resample_kernel: [1, 3, 3, 1]
74
+ decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
75
+ fix_decoder: true
76
+ num_mlp: 8
77
+ lr_mlp: 0.01
78
+ input_is_latent: true
79
+ different_w: true
80
+ narrow: 1
81
+ sft_half: true
82
+
83
+ network_d:
84
+ type: StyleGAN2Discriminator
85
+ out_size: 512
86
+ channel_multiplier: 1
87
+ resample_kernel: [1, 3, 3, 1]
88
+
89
+ network_d_left_eye:
90
+ type: FacialComponentDiscriminator
91
+
92
+ network_d_right_eye:
93
+ type: FacialComponentDiscriminator
94
+
95
+ network_d_mouth:
96
+ type: FacialComponentDiscriminator
97
+
98
+ network_identity:
99
+ type: ResNetArcFace
100
+ block: IRBlock
101
+ layers: [2, 2, 2, 2]
102
+ use_se: False
103
+
104
+ # path
105
+ path:
106
+ pretrain_network_g: ~
107
+ param_key_g: params_ema
108
+ strict_load_g: ~
109
+ pretrain_network_d: ~
110
+ pretrain_network_d_left_eye: ~
111
+ pretrain_network_d_right_eye: ~
112
+ pretrain_network_d_mouth: ~
113
+ pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
114
+ # resume
115
+ resume_state: ~
116
+ ignore_resume_networks: ['network_identity']
117
+
118
+ # training settings
119
+ train:
120
+ optim_g:
121
+ type: Adam
122
+ lr: !!float 2e-3
123
+ optim_d:
124
+ type: Adam
125
+ lr: !!float 2e-3
126
+ optim_component:
127
+ type: Adam
128
+ lr: !!float 2e-3
129
+
130
+ scheduler:
131
+ type: MultiStepLR
132
+ milestones: [600000, 700000]
133
+ gamma: 0.5
134
+
135
+ total_iter: 800000
136
+ warmup_iter: -1 # no warm up
137
+
138
+ # losses
139
+ # pixel loss
140
+ pixel_opt:
141
+ type: L1Loss
142
+ loss_weight: !!float 1e-1
143
+ reduction: mean
144
+ # L1 loss used in pyramid loss, component style loss and identity loss
145
+ L1_opt:
146
+ type: L1Loss
147
+ loss_weight: 1
148
+ reduction: mean
149
+
150
+ # image pyramid loss
151
+ pyramid_loss_weight: 1
152
+ remove_pyramid_loss: 50000
153
+ # perceptual loss (content and style losses)
154
+ perceptual_opt:
155
+ type: PerceptualLoss
156
+ layer_weights:
157
+ # before relu
158
+ 'conv1_2': 0.1
159
+ 'conv2_2': 0.1
160
+ 'conv3_4': 1
161
+ 'conv4_4': 1
162
+ 'conv5_4': 1
163
+ vgg_type: vgg19
164
+ use_input_norm: true
165
+ perceptual_weight: !!float 1
166
+ style_weight: 50
167
+ range_norm: true
168
+ criterion: l1
169
+ # gan loss
170
+ gan_opt:
171
+ type: GANLoss
172
+ gan_type: wgan_softplus
173
+ loss_weight: !!float 1e-1
174
+ # r1 regularization for discriminator
175
+ r1_reg_weight: 10
176
+ # facial component loss
177
+ gan_component_opt:
178
+ type: GANLoss
179
+ gan_type: vanilla
180
+ real_label_val: 1.0
181
+ fake_label_val: 0.0
182
+ loss_weight: !!float 1
183
+ comp_style_weight: 200
184
+ # identity loss
185
+ identity_weight: 10
186
+
187
+ net_d_iters: 1
188
+ net_d_init_iters: 0
189
+ net_d_reg_every: 16
190
+
191
+ # validation settings
192
+ val:
193
+ val_freq: !!float 5e3
194
+ save_img: true
195
+
196
+ metrics:
197
+ psnr: # metric name
198
+ type: calculate_psnr
199
+ crop_border: 0
200
+ test_y_channel: false
201
+
202
+ # logging settings
203
+ logger:
204
+ print_freq: 100
205
+ save_checkpoint_freq: !!float 5e3
206
+ use_tb_logger: true
207
+ wandb:
208
+ project: ~
209
+ resume_id: ~
210
+
211
+ # dist training settings
212
+ dist_params:
213
+ backend: nccl
214
+ port: 29500
215
+
216
+ find_unused_parameters: true
options/train_gfpgan_v1_simple.yml ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_GFPGANv1_512_simple
3
+ model_type: GFPGANModel
4
+ num_gpu: auto # officially, we use 4 GPUs
5
+ manual_seed: 0
6
+
7
+ # dataset and data loader settings
8
+ datasets:
9
+ train:
10
+ name: FFHQ
11
+ type: FFHQDegradationDataset
12
+ # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
13
+ dataroot_gt: datasets/ffhq/ffhq_512
14
+ io_backend:
15
+ # type: lmdb
16
+ type: disk
17
+
18
+ use_hflip: true
19
+ mean: [0.5, 0.5, 0.5]
20
+ std: [0.5, 0.5, 0.5]
21
+ out_size: 512
22
+
23
+ blur_kernel_size: 41
24
+ kernel_list: ['iso', 'aniso']
25
+ kernel_prob: [0.5, 0.5]
26
+ blur_sigma: [0.1, 10]
27
+ downsample_range: [0.8, 8]
28
+ noise_range: [0, 20]
29
+ jpeg_range: [60, 100]
30
+
31
+ # color jitter and gray
32
+ color_jitter_prob: 0.3
33
+ color_jitter_shift: 20
34
+ color_jitter_pt_prob: 0.3
35
+ gray_prob: 0.01
36
+
37
+ # If you do not want colorization, please set
38
+ # color_jitter_prob: ~
39
+ # color_jitter_pt_prob: ~
40
+ # gray_prob: 0.01
41
+ # gt_gray: True
42
+
43
+ # data loader
44
+ use_shuffle: true
45
+ num_worker_per_gpu: 6
46
+ batch_size_per_gpu: 3
47
+ dataset_enlarge_ratio: 1
48
+ prefetch_mode: ~
49
+
50
+ val:
51
+ # Please modify accordingly to use your own validation
52
+ # Or comment the val block if do not need validation during training
53
+ name: validation
54
+ type: PairedImageDataset
55
+ dataroot_lq: datasets/faces/validation/input
56
+ dataroot_gt: datasets/faces/validation/reference
57
+ io_backend:
58
+ type: disk
59
+ mean: [0.5, 0.5, 0.5]
60
+ std: [0.5, 0.5, 0.5]
61
+ scale: 1
62
+
63
+ # network structures
64
+ network_g:
65
+ type: GFPGANv1
66
+ out_size: 512
67
+ num_style_feat: 512
68
+ channel_multiplier: 1
69
+ resample_kernel: [1, 3, 3, 1]
70
+ decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
71
+ fix_decoder: true
72
+ num_mlp: 8
73
+ lr_mlp: 0.01
74
+ input_is_latent: true
75
+ different_w: true
76
+ narrow: 1
77
+ sft_half: true
78
+
79
+ network_d:
80
+ type: StyleGAN2Discriminator
81
+ out_size: 512
82
+ channel_multiplier: 1
83
+ resample_kernel: [1, 3, 3, 1]
84
+
85
+
86
+ # path
87
+ path:
88
+ pretrain_network_g: ~
89
+ param_key_g: params_ema
90
+ strict_load_g: ~
91
+ pretrain_network_d: ~
92
+ resume_state: ~
93
+
94
+ # training settings
95
+ train:
96
+ optim_g:
97
+ type: Adam
98
+ lr: !!float 2e-3
99
+ optim_d:
100
+ type: Adam
101
+ lr: !!float 2e-3
102
+ optim_component:
103
+ type: Adam
104
+ lr: !!float 2e-3
105
+
106
+ scheduler:
107
+ type: MultiStepLR
108
+ milestones: [600000, 700000]
109
+ gamma: 0.5
110
+
111
+ total_iter: 800000
112
+ warmup_iter: -1 # no warm up
113
+
114
+ # losses
115
+ # pixel loss
116
+ pixel_opt:
117
+ type: L1Loss
118
+ loss_weight: !!float 1e-1
119
+ reduction: mean
120
+ # L1 loss used in pyramid loss, component style loss and identity loss
121
+ L1_opt:
122
+ type: L1Loss
123
+ loss_weight: 1
124
+ reduction: mean
125
+
126
+ # image pyramid loss
127
+ pyramid_loss_weight: 1
128
+ remove_pyramid_loss: 50000
129
+ # perceptual loss (content and style losses)
130
+ perceptual_opt:
131
+ type: PerceptualLoss
132
+ layer_weights:
133
+ # before relu
134
+ 'conv1_2': 0.1
135
+ 'conv2_2': 0.1
136
+ 'conv3_4': 1
137
+ 'conv4_4': 1
138
+ 'conv5_4': 1
139
+ vgg_type: vgg19
140
+ use_input_norm: true
141
+ perceptual_weight: !!float 1
142
+ style_weight: 50
143
+ range_norm: true
144
+ criterion: l1
145
+ # gan loss
146
+ gan_opt:
147
+ type: GANLoss
148
+ gan_type: wgan_softplus
149
+ loss_weight: !!float 1e-1
150
+ # r1 regularization for discriminator
151
+ r1_reg_weight: 10
152
+
153
+ net_d_iters: 1
154
+ net_d_init_iters: 0
155
+ net_d_reg_every: 16
156
+
157
+ # validation settings
158
+ val:
159
+ val_freq: !!float 5e3
160
+ save_img: true
161
+
162
+ metrics:
163
+ psnr: # metric name
164
+ type: calculate_psnr
165
+ crop_border: 0
166
+ test_y_channel: false
167
+
168
+ # logging settings
169
+ logger:
170
+ print_freq: 100
171
+ save_checkpoint_freq: !!float 5e3
172
+ use_tb_logger: true
173
+ wandb:
174
+ project: ~
175
+ resume_id: ~
176
+
177
+ # dist training settings
178
+ dist_params:
179
+ backend: nccl
180
+ port: 29500
181
+
182
+ find_unused_parameters: true
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ basicsr>=1.3.4.0
2
+ facexlib>=0.2.3
3
+ lmdb
4
+ numpy<1.21 # numba requires numpy<1.21,>=1.17
5
+ opencv-python
6
+ pyyaml
7
+ scipy
8
+ tb-nightly
9
+ torch>=1.7
10
+ torchvision
11
+ tqdm
12
+ yapf
scripts/convert_gfpganv_to_clean.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import torch
4
+
5
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
6
+
7
+
8
+ def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
9
+ for ori_k, ori_v in checkpoint_bilinear.items():
10
+ if 'stylegan_decoder' in ori_k:
11
+ if 'style_mlp' in ori_k: # style_mlp_layers
12
+ lr_mul = 0.01
13
+ prefix, name, idx, var = ori_k.split('.')
14
+ idx = (int(idx) * 2) - 1
15
+ crt_k = f'{prefix}.{name}.{idx}.{var}'
16
+ if var == 'weight':
17
+ _, c_in = ori_v.size()
18
+ scale = (1 / math.sqrt(c_in)) * lr_mul
19
+ crt_v = ori_v * scale * 2**0.5
20
+ else:
21
+ crt_v = ori_v * lr_mul * 2**0.5
22
+ checkpoint_clean[crt_k] = crt_v
23
+ elif 'modulation' in ori_k: # modulation in StyleConv
24
+ lr_mul = 1
25
+ crt_k = ori_k
26
+ var = ori_k.split('.')[-1]
27
+ if var == 'weight':
28
+ _, c_in = ori_v.size()
29
+ scale = (1 / math.sqrt(c_in)) * lr_mul
30
+ crt_v = ori_v * scale
31
+ else:
32
+ crt_v = ori_v * lr_mul
33
+ checkpoint_clean[crt_k] = crt_v
34
+ elif 'style_conv' in ori_k:
35
+ # StyleConv in style_conv1 and style_convs
36
+ if 'activate' in ori_k: # FusedLeakyReLU
37
+ # eg. style_conv1.activate.bias
38
+ # eg. style_convs.13.activate.bias
39
+ split_rlt = ori_k.split('.')
40
+ if len(split_rlt) == 4:
41
+ prefix, name, _, var = split_rlt
42
+ crt_k = f'{prefix}.{name}.{var}'
43
+ elif len(split_rlt) == 5:
44
+ prefix, name, idx, _, var = split_rlt
45
+ crt_k = f'{prefix}.{name}.{idx}.{var}'
46
+ crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU
47
+ c = crt_v.size(0)
48
+ checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
49
+ elif 'modulated_conv' in ori_k:
50
+ # eg. style_conv1.modulated_conv.weight
51
+ # eg. style_convs.13.modulated_conv.weight
52
+ _, c_out, c_in, k1, k2 = ori_v.size()
53
+ scale = 1 / math.sqrt(c_in * k1 * k2)
54
+ crt_k = ori_k
55
+ checkpoint_clean[crt_k] = ori_v * scale
56
+ elif 'weight' in ori_k:
57
+ crt_k = ori_k
58
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
59
+ elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs
60
+ if 'modulated_conv' in ori_k:
61
+ # eg. to_rgb1.modulated_conv.weight
62
+ # eg. to_rgbs.5.modulated_conv.weight
63
+ _, c_out, c_in, k1, k2 = ori_v.size()
64
+ scale = 1 / math.sqrt(c_in * k1 * k2)
65
+ crt_k = ori_k
66
+ checkpoint_clean[crt_k] = ori_v * scale
67
+ else:
68
+ crt_k = ori_k
69
+ checkpoint_clean[crt_k] = ori_v
70
+ else:
71
+ crt_k = ori_k
72
+ checkpoint_clean[crt_k] = ori_v
73
+ # end of 'stylegan_decoder'
74
+ elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
75
+ # key name
76
+ name, _, var = ori_k.split('.')
77
+ crt_k = f'{name}.{var}'
78
+ # weight and bias
79
+ if var == 'weight':
80
+ c_out, c_in, k1, k2 = ori_v.size()
81
+ scale = 1 / math.sqrt(c_in * k1 * k2)
82
+ checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
83
+ else:
84
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
85
+ elif 'conv_body' in ori_k:
86
+ if 'conv_body_up' in ori_k:
87
+ ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
88
+ ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
89
+ name1, idx1, name2, _, var = ori_k.split('.')
90
+ crt_k = f'{name1}.{idx1}.{name2}.{var}'
91
+ if name2 == 'skip':
92
+ c_out, c_in, k1, k2 = ori_v.size()
93
+ scale = 1 / math.sqrt(c_in * k1 * k2)
94
+ checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
95
+ else:
96
+ if var == 'weight':
97
+ c_out, c_in, k1, k2 = ori_v.size()
98
+ scale = 1 / math.sqrt(c_in * k1 * k2)
99
+ checkpoint_clean[crt_k] = ori_v * scale
100
+ else:
101
+ checkpoint_clean[crt_k] = ori_v
102
+ if 'conv1' in ori_k:
103
+ checkpoint_clean[crt_k] *= 2**0.5
104
+ elif 'toRGB' in ori_k:
105
+ crt_k = ori_k
106
+ if 'weight' in ori_k:
107
+ c_out, c_in, k1, k2 = ori_v.size()
108
+ scale = 1 / math.sqrt(c_in * k1 * k2)
109
+ checkpoint_clean[crt_k] = ori_v * scale
110
+ else:
111
+ checkpoint_clean[crt_k] = ori_v
112
+ elif 'final_linear' in ori_k:
113
+ crt_k = ori_k
114
+ if 'weight' in ori_k:
115
+ _, c_in = ori_v.size()
116
+ scale = 1 / math.sqrt(c_in)
117
+ checkpoint_clean[crt_k] = ori_v * scale
118
+ else:
119
+ checkpoint_clean[crt_k] = ori_v
120
+ elif 'condition' in ori_k:
121
+ crt_k = ori_k
122
+ if '0.weight' in ori_k:
123
+ c_out, c_in, k1, k2 = ori_v.size()
124
+ scale = 1 / math.sqrt(c_in * k1 * k2)
125
+ checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
126
+ elif '0.bias' in ori_k:
127
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
128
+ elif '2.weight' in ori_k:
129
+ c_out, c_in, k1, k2 = ori_v.size()
130
+ scale = 1 / math.sqrt(c_in * k1 * k2)
131
+ checkpoint_clean[crt_k] = ori_v * scale
132
+ elif '2.bias' in ori_k:
133
+ checkpoint_clean[crt_k] = ori_v
134
+
135
+ return checkpoint_clean
136
+
137
+
138
+ if __name__ == '__main__':
139
+ parser = argparse.ArgumentParser()
140
+ parser.add_argument('--ori_path', type=str, help='Path to the original model')
141
+ parser.add_argument('--narrow', type=float, default=1)
142
+ parser.add_argument('--channel_multiplier', type=float, default=2)
143
+ parser.add_argument('--save_path', type=str)
144
+ args = parser.parse_args()
145
+
146
+ ori_ckpt = torch.load(args.ori_path)['params_ema']
147
+
148
+ net = GFPGANv1Clean(
149
+ 512,
150
+ num_style_feat=512,
151
+ channel_multiplier=args.channel_multiplier,
152
+ decoder_load_path=None,
153
+ fix_decoder=False,
154
+ # for stylegan decoder
155
+ num_mlp=8,
156
+ input_is_latent=True,
157
+ different_w=True,
158
+ narrow=args.narrow,
159
+ sft_half=True)
160
+ crt_ckpt = net.state_dict()
161
+
162
+ crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
163
+ print(f'Save to {args.save_path}.')
164
+ torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
scripts/parse_landmark.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from basicsr.utils import FileClient, imfrombytes
7
+ from collections import OrderedDict
8
+
9
+ # ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
10
+ # Configurations
11
+ save_img = False
12
+ scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
13
+ enlarge_ratio = 1.4 # only for eyes
14
+ json_path = 'ffhq-dataset-v2.json'
15
+ face_path = 'datasets/ffhq/ffhq_512.lmdb'
16
+ save_path = './FFHQ_eye_mouth_landmarks_512.pth'
17
+
18
+ print('Load JSON metadata...')
19
+ # use the official json file in FFHQ dataset
20
+ with open(json_path, 'rb') as f:
21
+ json_data = json.load(f, object_pairs_hook=OrderedDict)
22
+
23
+ print('Open LMDB file...')
24
+ # read ffhq images
25
+ file_client = FileClient('lmdb', db_paths=face_path)
26
+ with open(os.path.join(face_path, 'meta_info.txt')) as fin:
27
+ paths = [line.split('.')[0] for line in fin]
28
+
29
+ save_dict = {}
30
+
31
+ for item_idx, item in enumerate(json_data.values()):
32
+ print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True)
33
+
34
+ # parse landmarks
35
+ lm = np.array(item['image']['face_landmarks'])
36
+ lm = lm * scale
37
+
38
+ item_dict = {}
39
+ # get image
40
+ if save_img:
41
+ img_bytes = file_client.get(paths[item_idx])
42
+ img = imfrombytes(img_bytes, float32=True)
43
+
44
+ # get landmarks for each component
45
+ map_left_eye = list(range(36, 42))
46
+ map_right_eye = list(range(42, 48))
47
+ map_mouth = list(range(48, 68))
48
+
49
+ # eye_left
50
+ mean_left_eye = np.mean(lm[map_left_eye], 0) # (x, y)
51
+ half_len_left_eye = np.max((np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16))
52
+ item_dict['left_eye'] = [mean_left_eye[0], mean_left_eye[1], half_len_left_eye]
53
+ # mean_left_eye[0] = 512 - mean_left_eye[0] # for testing flip
54
+ half_len_left_eye *= enlarge_ratio
55
+ loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int)
56
+ if save_img:
57
+ eye_left_img = img[loc_left_eye[1]:loc_left_eye[3], loc_left_eye[0]:loc_left_eye[2], :]
58
+ cv2.imwrite(f'tmp/{item_idx:08d}_eye_left.png', eye_left_img * 255)
59
+
60
+ # eye_right
61
+ mean_right_eye = np.mean(lm[map_right_eye], 0)
62
+ half_len_right_eye = np.max((np.max(np.max(lm[map_right_eye], 0) - np.min(lm[map_right_eye], 0)) / 2, 16))
63
+ item_dict['right_eye'] = [mean_right_eye[0], mean_right_eye[1], half_len_right_eye]
64
+ # mean_right_eye[0] = 512 - mean_right_eye[0] # # for testing flip
65
+ half_len_right_eye *= enlarge_ratio
66
+ loc_right_eye = np.hstack(
67
+ (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int)
68
+ if save_img:
69
+ eye_right_img = img[loc_right_eye[1]:loc_right_eye[3], loc_right_eye[0]:loc_right_eye[2], :]
70
+ cv2.imwrite(f'tmp/{item_idx:08d}_eye_right.png', eye_right_img * 255)
71
+
72
+ # mouth
73
+ mean_mouth = np.mean(lm[map_mouth], 0)
74
+ half_len_mouth = np.max((np.max(np.max(lm[map_mouth], 0) - np.min(lm[map_mouth], 0)) / 2, 16))
75
+ item_dict['mouth'] = [mean_mouth[0], mean_mouth[1], half_len_mouth]
76
+ # mean_mouth[0] = 512 - mean_mouth[0] # for testing flip
77
+ loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int)
78
+ if save_img:
79
+ mouth_img = img[loc_mouth[1]:loc_mouth[3], loc_mouth[0]:loc_mouth[2], :]
80
+ cv2.imwrite(f'tmp/{item_idx:08d}_mouth.png', mouth_img * 255)
81
+
82
+ save_dict[f'{item_idx:08d}'] = item_dict
83
+
84
+ print('Save...')
85
+ torch.save(save_dict, save_path)
setup.cfg ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore =
3
+ # line break before binary operator (W503)
4
+ W503,
5
+ # line break after binary operator (W504)
6
+ W504,
7
+ max-line-length=120
8
+
9
+ [yapf]
10
+ based_on_style = pep8
11
+ column_limit = 120
12
+ blank_line_before_nested_class_or_def = true
13
+ split_before_expression_after_opening_paren = true
14
+
15
+ [isort]
16
+ line_length = 120
17
+ multi_line_output = 0
18
+ known_standard_library = pkg_resources,setuptools
19
+ known_first_party = gfpgan
20
+ known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
21
+ no_lines_before = STDLIB,LOCALFOLDER
22
+ default_section = THIRDPARTY
23
+
24
+ [codespell]
25
+ skip = .git,./docs/build
26
+ count =
27
+ quiet-level = 3
28
+
29
+ [aliases]
30
+ test=pytest
31
+
32
+ [tool:pytest]
33
+ addopts=tests/
setup.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+ import os
6
+ import subprocess
7
+ import time
8
+
9
+ version_file = 'gfpgan/version.py'
10
+
11
+
12
+ def readme():
13
+ with open('README.md', encoding='utf-8') as f:
14
+ content = f.read()
15
+ return content
16
+
17
+
18
+ def get_git_hash():
19
+
20
+ def _minimal_ext_cmd(cmd):
21
+ # construct minimal environment
22
+ env = {}
23
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
24
+ v = os.environ.get(k)
25
+ if v is not None:
26
+ env[k] = v
27
+ # LANGUAGE is used on win32
28
+ env['LANGUAGE'] = 'C'
29
+ env['LANG'] = 'C'
30
+ env['LC_ALL'] = 'C'
31
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
32
+ return out
33
+
34
+ try:
35
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
36
+ sha = out.strip().decode('ascii')
37
+ except OSError:
38
+ sha = 'unknown'
39
+
40
+ return sha
41
+
42
+
43
+ def get_hash():
44
+ if os.path.exists('.git'):
45
+ sha = get_git_hash()[:7]
46
+ else:
47
+ sha = 'unknown'
48
+
49
+ return sha
50
+
51
+
52
+ def write_version_py():
53
+ content = """# GENERATED VERSION FILE
54
+ # TIME: {}
55
+ __version__ = '{}'
56
+ __gitsha__ = '{}'
57
+ version_info = ({})
58
+ """
59
+ sha = get_hash()
60
+ with open('VERSION', 'r') as f:
61
+ SHORT_VERSION = f.read().strip()
62
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
63
+
64
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
65
+ with open(version_file, 'w') as f:
66
+ f.write(version_file_str)
67
+
68
+
69
+ def get_version():
70
+ with open(version_file, 'r') as f:
71
+ exec(compile(f.read(), version_file, 'exec'))
72
+ return locals()['__version__']
73
+
74
+
75
+ def get_requirements(filename='requirements.txt'):
76
+ here = os.path.dirname(os.path.realpath(__file__))
77
+ with open(os.path.join(here, filename), 'r') as f:
78
+ requires = [line.replace('\n', '') for line in f.readlines()]
79
+ return requires
80
+
81
+
82
+ if __name__ == '__main__':
83
+ write_version_py()
84
+ setup(
85
+ name='gfpgan',
86
+ version=get_version(),
87
+ description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
88
+ long_description=readme(),
89
+ long_description_content_type='text/markdown',
90
+ author='Xintao Wang',
91
+ author_email='xintao.wang@outlook.com',
92
+ keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
93
+ url='https://github.com/TencentARC/GFPGAN',
94
+ include_package_data=True,
95
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
96
+ classifiers=[
97
+ 'Development Status :: 4 - Beta',
98
+ 'License :: OSI Approved :: Apache Software License',
99
+ 'Operating System :: OS Independent',
100
+ 'Programming Language :: Python :: 3',
101
+ 'Programming Language :: Python :: 3.7',
102
+ 'Programming Language :: Python :: 3.8',
103
+ ],
104
+ license='Apache License Version 2.0',
105
+ setup_requires=['cython', 'numpy'],
106
+ install_requires=get_requirements(),
107
+ zip_safe=False)
tests/data/ffhq_gt.lmdb/data.mdb ADDED
Binary file (455 kB). View file
 
tests/data/ffhq_gt.lmdb/lock.mdb ADDED
Binary file (8.19 kB). View file
 
tests/data/ffhq_gt.lmdb/meta_info.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 00000000.png (512,512,3) 1
tests/data/test_eye_mouth_landmarks.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:131583fca2cc346652f8754eb3c5a0bdeda808686039ff10ead7a26254b72358
3
+ size 943
tests/data/test_ffhq_degradation_dataset.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: UnitTest
2
+ type: FFHQDegradationDataset
3
+ dataroot_gt: tests/data/gt
4
+ io_backend:
5
+ type: disk
6
+
7
+ use_hflip: true
8
+ mean: [0.5, 0.5, 0.5]
9
+ std: [0.5, 0.5, 0.5]
10
+ out_size: 512
11
+
12
+ blur_kernel_size: 41
13
+ kernel_list: ['iso', 'aniso']
14
+ kernel_prob: [0.5, 0.5]
15
+ blur_sigma: [0.1, 10]
16
+ downsample_range: [0.8, 8]
17
+ noise_range: [0, 20]
18
+ jpeg_range: [60, 100]
19
+
20
+ # color jitter and gray
21
+ color_jitter_prob: 1
22
+ color_jitter_shift: 20
23
+ color_jitter_pt_prob: 1
24
+ gray_prob: 1
tests/data/test_gfpgan_model.yml ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ num_gpu: 1
2
+ manual_seed: 0
3
+ is_train: True
4
+ dist: False
5
+
6
+ # network structures
7
+ network_g:
8
+ type: GFPGANv1
9
+ out_size: 512
10
+ num_style_feat: 512
11
+ channel_multiplier: 1
12
+ resample_kernel: [1, 3, 3, 1]
13
+ decoder_load_path: ~
14
+ fix_decoder: true
15
+ num_mlp: 8
16
+ lr_mlp: 0.01
17
+ input_is_latent: true
18
+ different_w: true
19
+ narrow: 0.5
20
+ sft_half: true
21
+
22
+ network_d:
23
+ type: StyleGAN2Discriminator
24
+ out_size: 512
25
+ channel_multiplier: 1
26
+ resample_kernel: [1, 3, 3, 1]
27
+
28
+ network_d_left_eye:
29
+ type: FacialComponentDiscriminator
30
+
31
+ network_d_right_eye:
32
+ type: FacialComponentDiscriminator
33
+
34
+ network_d_mouth:
35
+ type: FacialComponentDiscriminator
36
+
37
+ network_identity:
38
+ type: ResNetArcFace
39
+ block: IRBlock
40
+ layers: [2, 2, 2, 2]
41
+ use_se: False
42
+
43
+ # path
44
+ path:
45
+ pretrain_network_g: ~
46
+ param_key_g: params_ema
47
+ strict_load_g: ~
48
+ pretrain_network_d: ~
49
+ pretrain_network_d_left_eye: ~
50
+ pretrain_network_d_right_eye: ~
51
+ pretrain_network_d_mouth: ~
52
+ pretrain_network_identity: ~
53
+ # resume
54
+ resume_state: ~
55
+ ignore_resume_networks: ['network_identity']
56
+
57
+ # training settings
58
+ train:
59
+ optim_g:
60
+ type: Adam
61
+ lr: !!float 2e-3
62
+ optim_d:
63
+ type: Adam
64
+ lr: !!float 2e-3
65
+ optim_component:
66
+ type: Adam
67
+ lr: !!float 2e-3
68
+
69
+ scheduler:
70
+ type: MultiStepLR
71
+ milestones: [600000, 700000]
72
+ gamma: 0.5
73
+
74
+ total_iter: 800000
75
+ warmup_iter: -1 # no warm up
76
+
77
+ # losses
78
+ # pixel loss
79
+ pixel_opt:
80
+ type: L1Loss
81
+ loss_weight: !!float 1e-1
82
+ reduction: mean
83
+ # L1 loss used in pyramid loss, component style loss and identity loss
84
+ L1_opt:
85
+ type: L1Loss
86
+ loss_weight: 1
87
+ reduction: mean
88
+
89
+ # image pyramid loss
90
+ pyramid_loss_weight: 1
91
+ remove_pyramid_loss: 50000
92
+ # perceptual loss (content and style losses)
93
+ perceptual_opt:
94
+ type: PerceptualLoss
95
+ layer_weights:
96
+ # before relu
97
+ 'conv1_2': 0.1
98
+ 'conv2_2': 0.1
99
+ 'conv3_4': 1
100
+ 'conv4_4': 1
101
+ 'conv5_4': 1
102
+ vgg_type: vgg19
103
+ use_input_norm: true
104
+ perceptual_weight: !!float 1
105
+ style_weight: 50
106
+ range_norm: true
107
+ criterion: l1
108
+ # gan loss
109
+ gan_opt:
110
+ type: GANLoss
111
+ gan_type: wgan_softplus
112
+ loss_weight: !!float 1e-1
113
+ # r1 regularization for discriminator
114
+ r1_reg_weight: 10
115
+ # facial component loss
116
+ gan_component_opt:
117
+ type: GANLoss
118
+ gan_type: vanilla
119
+ real_label_val: 1.0
120
+ fake_label_val: 0.0
121
+ loss_weight: !!float 1
122
+ comp_style_weight: 200
123
+ # identity loss
124
+ identity_weight: 10
125
+
126
+ net_d_iters: 1
127
+ net_d_init_iters: 0
128
+ net_d_reg_every: 1
129
+
130
+ # validation settings
131
+ val:
132
+ val_freq: !!float 5e3
133
+ save_img: True
134
+ use_pbar: True
135
+
136
+ metrics:
137
+ psnr: # metric name
138
+ type: calculate_psnr
139
+ crop_border: 0
140
+ test_y_channel: false
tests/test_arcface_arch.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
4
+
5
+
6
+ def test_resnetarcface():
7
+ """Test arch: ResNetArcFace."""
8
+
9
+ # model init and forward (gpu)
10
+ if torch.cuda.is_available():
11
+ net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
12
+ img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
13
+ output = net(img)
14
+ assert output.shape == (1, 512)
15
+
16
+ # -------------------- without SE block ----------------------- #
17
+ net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
18
+ output = net(img)
19
+ assert output.shape == (1, 512)
20
+
21
+
22
+ def test_basicblock():
23
+ """Test the BasicBlock in arcface_arch"""
24
+ block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
25
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
26
+ output = block(img)
27
+ assert output.shape == (1, 3, 12, 12)
28
+
29
+ # ----------------- use the downsmaple module--------------- #
30
+ downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
31
+ block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
32
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
33
+ output = block(img)
34
+ assert output.shape == (1, 3, 6, 6)
35
+
36
+
37
+ def test_bottleneck():
38
+ """Test the Bottleneck in arcface_arch"""
39
+ block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
40
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
41
+ output = block(img)
42
+ assert output.shape == (1, 4, 12, 12)
43
+
44
+ # ----------------- use the downsmaple module--------------- #
45
+ downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
46
+ block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
47
+ img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
48
+ output = block(img)
49
+ assert output.shape == (1, 4, 6, 6)
tests/test_ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import yaml
3
+
4
+ from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
5
+
6
+
7
+ def test_ffhq_degradation_dataset():
8
+
9
+ with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
10
+ opt = yaml.load(f, Loader=yaml.FullLoader)
11
+
12
+ dataset = FFHQDegradationDataset(opt)
13
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
14
+ assert len(dataset) == 1 # whether to read correct meta info
15
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
16
+ assert dataset.color_jitter_prob == 1
17
+
18
+ # test __getitem__
19
+ result = dataset.__getitem__(0)
20
+ # check returned keys
21
+ expected_keys = ['gt', 'lq', 'gt_path']
22
+ assert set(expected_keys).issubset(set(result.keys()))
23
+ # check shape and contents
24
+ assert result['gt'].shape == (3, 512, 512)
25
+ assert result['lq'].shape == (3, 512, 512)
26
+ assert result['gt_path'] == 'tests/data/gt/00000000.png'
27
+
28
+ # ------------------ test with probability = 0 -------------------- #
29
+ opt['color_jitter_prob'] = 0
30
+ opt['color_jitter_pt_prob'] = 0
31
+ opt['gray_prob'] = 0
32
+ opt['io_backend'] = dict(type='disk')
33
+ dataset = FFHQDegradationDataset(opt)
34
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
35
+ assert len(dataset) == 1 # whether to read correct meta info
36
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
37
+ assert dataset.color_jitter_prob == 0
38
+
39
+ # test __getitem__
40
+ result = dataset.__getitem__(0)
41
+ # check returned keys
42
+ expected_keys = ['gt', 'lq', 'gt_path']
43
+ assert set(expected_keys).issubset(set(result.keys()))
44
+ # check shape and contents
45
+ assert result['gt'].shape == (3, 512, 512)
46
+ assert result['lq'].shape == (3, 512, 512)
47
+ assert result['gt_path'] == 'tests/data/gt/00000000.png'
48
+
49
+ # ------------------ test lmdb backend -------------------- #
50
+ opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
51
+ opt['io_backend'] = dict(type='lmdb')
52
+
53
+ dataset = FFHQDegradationDataset(opt)
54
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
55
+ assert len(dataset) == 1 # whether to read correct meta info
56
+ assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
57
+ assert dataset.color_jitter_prob == 0
58
+
59
+ # test __getitem__
60
+ result = dataset.__getitem__(0)
61
+ # check returned keys
62
+ expected_keys = ['gt', 'lq', 'gt_path']
63
+ assert set(expected_keys).issubset(set(result.keys()))
64
+ # check shape and contents
65
+ assert result['gt'].shape == (3, 512, 512)
66
+ assert result['lq'].shape == (3, 512, 512)
67
+ assert result['gt_path'] == '00000000'
68
+
69
+ # ------------------ test with crop_components -------------------- #
70
+ opt['crop_components'] = True
71
+ opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
72
+ opt['eye_enlarge_ratio'] = 1.4
73
+ opt['gt_gray'] = True
74
+ opt['io_backend'] = dict(type='lmdb')
75
+
76
+ dataset = FFHQDegradationDataset(opt)
77
+ assert dataset.crop_components is True
78
+
79
+ # test __getitem__
80
+ result = dataset.__getitem__(0)
81
+ # check returned keys
82
+ expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
83
+ assert set(expected_keys).issubset(set(result.keys()))
84
+ # check shape and contents
85
+ assert result['gt'].shape == (3, 512, 512)
86
+ assert result['lq'].shape == (3, 512, 512)
87
+ assert result['gt_path'] == '00000000'
88
+ assert result['loc_left_eye'].shape == (4, )
89
+ assert result['loc_right_eye'].shape == (4, )
90
+ assert result['loc_mouth'].shape == (4, )
91
+
92
+ # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
93
+ with pytest.raises(ValueError):
94
+ opt['dataroot_gt'] = 'tests/data/gt'
95
+ opt['io_backend'] = dict(type='lmdb')
96
+ dataset = FFHQDegradationDataset(opt)
tests/test_gfpgan_arch.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
4
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
5
+
6
+
7
+ def test_stylegan2generatorsft():
8
+ """Test arch: StyleGAN2GeneratorSFT."""
9
+
10
+ # model init and forward (gpu)
11
+ if torch.cuda.is_available():
12
+ net = StyleGAN2GeneratorSFT(
13
+ out_size=32,
14
+ num_style_feat=512,
15
+ num_mlp=8,
16
+ channel_multiplier=1,
17
+ resample_kernel=(1, 3, 3, 1),
18
+ lr_mlp=0.01,
19
+ narrow=1,
20
+ sft_half=False).cuda().eval()
21
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
22
+ condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
23
+ condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
24
+ condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
25
+ conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
26
+ output = net([style], conditions)
27
+ assert output[0].shape == (1, 3, 32, 32)
28
+ assert output[1] is None
29
+
30
+ # -------------------- with return_latents ----------------------- #
31
+ output = net([style], conditions, return_latents=True)
32
+ assert output[0].shape == (1, 3, 32, 32)
33
+ assert len(output[1]) == 1
34
+ # check latent
35
+ assert output[1][0].shape == (8, 512)
36
+
37
+ # -------------------- with randomize_noise = False ----------------------- #
38
+ output = net([style], conditions, randomize_noise=False)
39
+ assert output[0].shape == (1, 3, 32, 32)
40
+ assert output[1] is None
41
+
42
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
43
+ output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
44
+ assert output[0].shape == (1, 3, 32, 32)
45
+ assert output[1] is None
46
+
47
+
48
+ def test_gfpganv1():
49
+ """Test arch: GFPGANv1."""
50
+
51
+ # model init and forward (gpu)
52
+ if torch.cuda.is_available():
53
+ net = GFPGANv1(
54
+ out_size=32,
55
+ num_style_feat=512,
56
+ channel_multiplier=1,
57
+ resample_kernel=(1, 3, 3, 1),
58
+ decoder_load_path=None,
59
+ fix_decoder=True,
60
+ # for stylegan decoder
61
+ num_mlp=8,
62
+ lr_mlp=0.01,
63
+ input_is_latent=False,
64
+ different_w=False,
65
+ narrow=1,
66
+ sft_half=True).cuda().eval()
67
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
68
+ output = net(img)
69
+ assert output[0].shape == (1, 3, 32, 32)
70
+ assert len(output[1]) == 3
71
+ # check out_rgbs for intermediate loss
72
+ assert output[1][0].shape == (1, 3, 8, 8)
73
+ assert output[1][1].shape == (1, 3, 16, 16)
74
+ assert output[1][2].shape == (1, 3, 32, 32)
75
+
76
+ # -------------------- with different_w = True ----------------------- #
77
+ net = GFPGANv1(
78
+ out_size=32,
79
+ num_style_feat=512,
80
+ channel_multiplier=1,
81
+ resample_kernel=(1, 3, 3, 1),
82
+ decoder_load_path=None,
83
+ fix_decoder=True,
84
+ # for stylegan decoder
85
+ num_mlp=8,
86
+ lr_mlp=0.01,
87
+ input_is_latent=False,
88
+ different_w=True,
89
+ narrow=1,
90
+ sft_half=True).cuda().eval()
91
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
92
+ output = net(img)
93
+ assert output[0].shape == (1, 3, 32, 32)
94
+ assert len(output[1]) == 3
95
+ # check out_rgbs for intermediate loss
96
+ assert output[1][0].shape == (1, 3, 8, 8)
97
+ assert output[1][1].shape == (1, 3, 16, 16)
98
+ assert output[1][2].shape == (1, 3, 32, 32)
99
+
100
+
101
+ def test_facialcomponentdiscriminator():
102
+ """Test arch: FacialComponentDiscriminator."""
103
+
104
+ # model init and forward (gpu)
105
+ if torch.cuda.is_available():
106
+ net = FacialComponentDiscriminator().cuda().eval()
107
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
108
+ output = net(img)
109
+ assert len(output) == 2
110
+ assert output[0].shape == (1, 1, 8, 8)
111
+ assert output[1] is None
112
+
113
+ # -------------------- return intermediate features ----------------------- #
114
+ output = net(img, return_feats=True)
115
+ assert len(output) == 2
116
+ assert output[0].shape == (1, 1, 8, 8)
117
+ assert len(output[1]) == 2
118
+ assert output[1][0].shape == (1, 128, 16, 16)
119
+ assert output[1][1].shape == (1, 256, 8, 8)
120
+
121
+
122
+ def test_stylegan2generatorcsft():
123
+ """Test arch: StyleGAN2GeneratorCSFT."""
124
+
125
+ # model init and forward (gpu)
126
+ if torch.cuda.is_available():
127
+ net = StyleGAN2GeneratorCSFT(
128
+ out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
129
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
130
+ condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
131
+ condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
132
+ condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
133
+ conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
134
+ output = net([style], conditions)
135
+ assert output[0].shape == (1, 3, 32, 32)
136
+ assert output[1] is None
137
+
138
+ # -------------------- with return_latents ----------------------- #
139
+ output = net([style], conditions, return_latents=True)
140
+ assert output[0].shape == (1, 3, 32, 32)
141
+ assert len(output[1]) == 1
142
+ # check latent
143
+ assert output[1][0].shape == (8, 512)
144
+
145
+ # -------------------- with randomize_noise = False ----------------------- #
146
+ output = net([style], conditions, randomize_noise=False)
147
+ assert output[0].shape == (1, 3, 32, 32)
148
+ assert output[1] is None
149
+
150
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
151
+ output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
152
+ assert output[0].shape == (1, 3, 32, 32)
153
+ assert output[1] is None
154
+
155
+
156
+ def test_gfpganv1clean():
157
+ """Test arch: GFPGANv1Clean."""
158
+
159
+ # model init and forward (gpu)
160
+ if torch.cuda.is_available():
161
+ net = GFPGANv1Clean(
162
+ out_size=32,
163
+ num_style_feat=512,
164
+ channel_multiplier=1,
165
+ decoder_load_path=None,
166
+ fix_decoder=True,
167
+ # for stylegan decoder
168
+ num_mlp=8,
169
+ input_is_latent=False,
170
+ different_w=False,
171
+ narrow=1,
172
+ sft_half=True).cuda().eval()
173
+
174
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
175
+ output = net(img)
176
+ assert output[0].shape == (1, 3, 32, 32)
177
+ assert len(output[1]) == 3
178
+ # check out_rgbs for intermediate loss
179
+ assert output[1][0].shape == (1, 3, 8, 8)
180
+ assert output[1][1].shape == (1, 3, 16, 16)
181
+ assert output[1][2].shape == (1, 3, 32, 32)
182
+
183
+ # -------------------- with different_w = True ----------------------- #
184
+ net = GFPGANv1Clean(
185
+ out_size=32,
186
+ num_style_feat=512,
187
+ channel_multiplier=1,
188
+ decoder_load_path=None,
189
+ fix_decoder=True,
190
+ # for stylegan decoder
191
+ num_mlp=8,
192
+ input_is_latent=False,
193
+ different_w=True,
194
+ narrow=1,
195
+ sft_half=True).cuda().eval()
196
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
197
+ output = net(img)
198
+ assert output[0].shape == (1, 3, 32, 32)
199
+ assert len(output[1]) == 3
200
+ # check out_rgbs for intermediate loss
201
+ assert output[1][0].shape == (1, 3, 8, 8)
202
+ assert output[1][1].shape == (1, 3, 16, 16)
203
+ assert output[1][2].shape == (1, 3, 32, 32)
tests/test_gfpgan_model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import torch
3
+ import yaml
4
+ from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
5
+ from basicsr.data.paired_image_dataset import PairedImageDataset
6
+ from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
7
+
8
+ from gfpgan.archs.arcface_arch import ResNetArcFace
9
+ from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
10
+ from gfpgan.models.gfpgan_model import GFPGANModel
11
+
12
+
13
+ def test_gfpgan_model():
14
+ with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
15
+ opt = yaml.load(f, Loader=yaml.FullLoader)
16
+
17
+ # build model
18
+ model = GFPGANModel(opt)
19
+ # test attributes
20
+ assert model.__class__.__name__ == 'GFPGANModel'
21
+ assert isinstance(model.net_g, GFPGANv1) # generator
22
+ assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
23
+ # facial component discriminators
24
+ assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
25
+ assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
26
+ assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
27
+ # identity network
28
+ assert isinstance(model.network_identity, ResNetArcFace)
29
+ # losses
30
+ assert isinstance(model.cri_pix, L1Loss)
31
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
32
+ assert isinstance(model.cri_gan, GANLoss)
33
+ assert isinstance(model.cri_l1, L1Loss)
34
+ # optimizer
35
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
36
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
37
+
38
+ # prepare data
39
+ gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
40
+ lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
41
+ loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
42
+ loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
43
+ loc_mouth = torch.rand((1, 4), dtype=torch.float32)
44
+ data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
45
+ model.feed_data(data)
46
+ # check data shape
47
+ assert model.lq.shape == (1, 3, 512, 512)
48
+ assert model.gt.shape == (1, 3, 512, 512)
49
+ assert model.loc_left_eyes.shape == (1, 4)
50
+ assert model.loc_right_eyes.shape == (1, 4)
51
+ assert model.loc_mouths.shape == (1, 4)
52
+
53
+ # ----------------- test optimize_parameters -------------------- #
54
+ model.feed_data(data)
55
+ model.optimize_parameters(1)
56
+ assert model.output.shape == (1, 3, 512, 512)
57
+ assert isinstance(model.log_dict, dict)
58
+ # check returned keys
59
+ expected_keys = [
60
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
61
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
62
+ 'l_d_right_eye', 'l_d_mouth'
63
+ ]
64
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
65
+
66
+ # ----------------- remove pyramid_loss_weight-------------------- #
67
+ model.feed_data(data)
68
+ model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
69
+ assert model.output.shape == (1, 3, 512, 512)
70
+ assert isinstance(model.log_dict, dict)
71
+ # check returned keys
72
+ expected_keys = [
73
+ 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
74
+ 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
75
+ 'l_d_right_eye', 'l_d_mouth'
76
+ ]
77
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
78
+
79
+ # ----------------- test save -------------------- #
80
+ with tempfile.TemporaryDirectory() as tmpdir:
81
+ model.opt['path']['models'] = tmpdir
82
+ model.opt['path']['training_states'] = tmpdir
83
+ model.save(0, 1)
84
+
85
+ # ----------------- test the test function -------------------- #
86
+ model.test()
87
+ assert model.output.shape == (1, 3, 512, 512)
88
+ # delete net_g_ema
89
+ model.__delattr__('net_g_ema')
90
+ model.test()
91
+ assert model.output.shape == (1, 3, 512, 512)
92
+ assert model.net_g.training is True # should back to training mode after testing
93
+
94
+ # ----------------- test nondist_validation -------------------- #
95
+ # construct dataloader
96
+ dataset_opt = dict(
97
+ name='Demo',
98
+ dataroot_gt='tests/data/gt',
99
+ dataroot_lq='tests/data/gt',
100
+ io_backend=dict(type='disk'),
101
+ scale=4,
102
+ phase='val')
103
+ dataset = PairedImageDataset(dataset_opt)
104
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
105
+ assert model.is_train is True
106
+ with tempfile.TemporaryDirectory() as tmpdir:
107
+ model.opt['path']['visualization'] = tmpdir
108
+ model.nondist_validation(dataloader, 1, None, save_img=True)
109
+ assert model.is_train is True
110
+ # check metric_results
111
+ assert 'psnr' in model.metric_results
112
+ assert isinstance(model.metric_results['psnr'], float)
113
+
114
+ # validation
115
+ with tempfile.TemporaryDirectory() as tmpdir:
116
+ model.opt['is_train'] = False
117
+ model.opt['val']['suffix'] = 'test'
118
+ model.opt['path']['visualization'] = tmpdir
119
+ model.opt['val']['pbar'] = True
120
+ model.nondist_validation(dataloader, 1, None, save_img=True)
121
+ # check metric_results
122
+ assert 'psnr' in model.metric_results
123
+ assert isinstance(model.metric_results['psnr'], float)
124
+
125
+ # if opt['val']['suffix'] is None
126
+ model.opt['val']['suffix'] = None
127
+ model.opt['name'] = 'demo'
128
+ model.opt['path']['visualization'] = tmpdir
129
+ model.nondist_validation(dataloader, 1, None, save_img=True)
130
+ # check metric_results
131
+ assert 'psnr' in model.metric_results
132
+ assert isinstance(model.metric_results['psnr'], float)
tests/test_stylegan2_clean_arch.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
4
+
5
+
6
+ def test_stylegan2generatorclean():
7
+ """Test arch: StyleGAN2GeneratorClean."""
8
+
9
+ # model init and forward (gpu)
10
+ if torch.cuda.is_available():
11
+ net = StyleGAN2GeneratorClean(
12
+ out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
13
+ style = torch.rand((1, 512), dtype=torch.float32).cuda()
14
+ output = net([style], input_is_latent=False)
15
+ assert output[0].shape == (1, 3, 32, 32)
16
+ assert output[1] is None
17
+
18
+ # -------------------- with return_latents ----------------------- #
19
+ output = net([style], input_is_latent=True, return_latents=True)
20
+ assert output[0].shape == (1, 3, 32, 32)
21
+ assert len(output[1]) == 1
22
+ # check latent
23
+ assert output[1][0].shape == (8, 512)
24
+
25
+ # -------------------- with randomize_noise = False ----------------------- #
26
+ output = net([style], randomize_noise=False)
27
+ assert output[0].shape == (1, 3, 32, 32)
28
+ assert output[1] is None
29
+
30
+ # -------------------- with truncation = 0.5 and mixing----------------------- #
31
+ output = net([style, style], truncation=0.5, truncation_latent=style)
32
+ assert output[0].shape == (1, 3, 32, 32)
33
+ assert output[1] is None
34
+
35
+ # ------------------ test make_noise ----------------------- #
36
+ out = net.make_noise()
37
+ assert len(out) == 7
38
+ assert out[0].shape == (1, 1, 4, 4)
39
+ assert out[1].shape == (1, 1, 8, 8)
40
+ assert out[2].shape == (1, 1, 8, 8)
41
+ assert out[3].shape == (1, 1, 16, 16)
42
+ assert out[4].shape == (1, 1, 16, 16)
43
+ assert out[5].shape == (1, 1, 32, 32)
44
+ assert out[6].shape == (1, 1, 32, 32)
45
+
46
+ # ------------------ test get_latent ----------------------- #
47
+ out = net.get_latent(style)
48
+ assert out.shape == (1, 512)
49
+
50
+ # ------------------ test mean_latent ----------------------- #
51
+ out = net.mean_latent(2)
52
+ assert out.shape == (1, 512)
tests/test_utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
3
+
4
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
5
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
6
+ from gfpgan.utils import GFPGANer
7
+
8
+
9
+ def test_gfpganer():
10
+ # initialize with the clean model
11
+ restorer = GFPGANer(
12
+ model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
13
+ upscale=2,
14
+ arch='clean',
15
+ channel_multiplier=2,
16
+ bg_upsampler=None)
17
+ # test attribute
18
+ assert isinstance(restorer.gfpgan, GFPGANv1Clean)
19
+ assert isinstance(restorer.face_helper, FaceRestoreHelper)
20
+
21
+ # initialize with the original model
22
+ restorer = GFPGANer(
23
+ model_path='experiments/pretrained_models/GFPGANv1.pth',
24
+ upscale=2,
25
+ arch='original',
26
+ channel_multiplier=1,
27
+ bg_upsampler=None)
28
+ # test attribute
29
+ assert isinstance(restorer.gfpgan, GFPGANv1)
30
+ assert isinstance(restorer.face_helper, FaceRestoreHelper)
31
+
32
+ # ------------------ test enhance ---------------- #
33
+ img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
34
+ result = restorer.enhance(img, has_aligned=False, paste_back=True)
35
+ assert result[0][0].shape == (512, 512, 3)
36
+ assert result[1][0].shape == (512, 512, 3)
37
+ assert result[2].shape == (1024, 1024, 3)
38
+
39
+ # with has_aligned=True
40
+ result = restorer.enhance(img, has_aligned=True, paste_back=False)
41
+ assert result[0][0].shape == (512, 512, 3)
42
+ assert result[1][0].shape == (512, 512, 3)
43
+ assert result[2] is None