zamen01 commited on
Commit
84595ee
1 Parent(s): e758faa

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .gitignore +157 -0
  3. Dockerfile +28 -0
  4. DragGAN.gif +3 -0
  5. LICENSE.txt +97 -0
  6. README.md +140 -8
  7. __pycache__/legacy.cpython-39.pyc +0 -0
  8. arial.ttf +0 -0
  9. checkpoints/stylegan2-afhqcat-512x512.pkl +3 -0
  10. checkpoints/stylegan2-car-config-f.pkl +3 -0
  11. checkpoints/stylegan2-cat-config-f.pkl +3 -0
  12. checkpoints/stylegan2-ffhq-512x512.pkl +3 -0
  13. checkpoints/stylegan2-lhq-256x256.pkl +3 -0
  14. checkpoints/stylegan2_dogs_1024_pytorch.pkl +3 -0
  15. checkpoints/stylegan2_elephants_512_pytorch.pkl +3 -0
  16. checkpoints/stylegan2_horses_256_pytorch.pkl +3 -0
  17. checkpoints/stylegan2_lions_512_pytorch.pkl +3 -0
  18. checkpoints/stylegan_human_v2_512.pkl +3 -0
  19. dnnlib/__init__.py +9 -0
  20. dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  21. dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  22. dnnlib/util.py +491 -0
  23. environment.yml +27 -0
  24. gen_images.py +150 -0
  25. gradio_utils/__init__.py +9 -0
  26. gradio_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  27. gradio_utils/__pycache__/utils.cpython-39.pyc +0 -0
  28. gradio_utils/utils.py +154 -0
  29. gui_utils/__init__.py +9 -0
  30. gui_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  31. gui_utils/__pycache__/gl_utils.cpython-39.pyc +0 -0
  32. gui_utils/__pycache__/glfw_window.cpython-39.pyc +0 -0
  33. gui_utils/__pycache__/imgui_utils.cpython-39.pyc +0 -0
  34. gui_utils/__pycache__/imgui_window.cpython-39.pyc +0 -0
  35. gui_utils/__pycache__/text_utils.cpython-39.pyc +0 -0
  36. gui_utils/gl_utils.py +416 -0
  37. gui_utils/glfw_window.py +229 -0
  38. gui_utils/imgui_utils.py +191 -0
  39. gui_utils/imgui_window.py +103 -0
  40. gui_utils/text_utils.py +123 -0
  41. legacy.py +323 -0
  42. requirements.txt +13 -0
  43. scripts/download_model.py +78 -0
  44. scripts/download_models.json +10 -0
  45. scripts/gui.bat +12 -0
  46. scripts/gui.sh +11 -0
  47. stylegan_human/.gitignore +10 -0
  48. stylegan_human/PP_HumanSeg/deploy/infer.py +180 -0
  49. stylegan_human/PP_HumanSeg/export_model/download_export_model.py +44 -0
  50. stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py +44 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ DragGAN.gif filter=lfs diff=lfs merge=lfs -text
37
+ stylegan_human/img/demo_V5_thumbnails-min.png filter=lfs diff=lfs merge=lfs -text
38
+ stylegan_human/img/preview_samples1.png filter=lfs diff=lfs merge=lfs -text
39
+ stylegan_human/img/test/test.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by .ignore support plugin (hsz.mobi)
2
+ ### Python template
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ env/
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *,cover
48
+ .hypothesis/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+
58
+ # Flask stuff:
59
+ instance/
60
+ .webassets-cache
61
+
62
+ # Scrapy stuff:
63
+ .scrapy
64
+
65
+ # Sphinx documentation
66
+ docs/_build/
67
+
68
+ # PyBuilder
69
+ target/
70
+
71
+ # IPython Notebook
72
+ .ipynb_checkpoints
73
+
74
+ # pyenv
75
+ .python-version
76
+
77
+ # celery beat schedule file
78
+ celerybeat-schedule
79
+
80
+ # dotenv
81
+ .env
82
+
83
+ # virtualenv
84
+ venv/
85
+ ENV/
86
+
87
+ # Spyder project settings
88
+ .spyderproject
89
+
90
+ # Rope project settings
91
+ .ropeproject
92
+ ### VirtualEnv template
93
+ # Virtualenv
94
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
95
+ .Python
96
+ [Bb]in
97
+ [Ii]nclude
98
+ [Ll]ib
99
+ [Ll]ib64
100
+ [Ll]ocal
101
+ [Ss]cripts
102
+ !scripts/
103
+ pyvenv.cfg
104
+ .venv
105
+ pip-selfcheck.json
106
+ ### JetBrains template
107
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
108
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
109
+
110
+ # User-specific stuff:
111
+ .idea/workspace.xml
112
+ .idea/tasks.xml
113
+ .idea/dictionaries
114
+ .idea/vcs.xml
115
+ .idea/jsLibraryMappings.xml
116
+
117
+ # Sensitive or high-churn files:
118
+ .idea/dataSources.ids
119
+ .idea/dataSources.xml
120
+ .idea/dataSources.local.xml
121
+ .idea/sqlDataSources.xml
122
+ .idea/dynamic.xml
123
+ .idea/uiDesigner.xml
124
+
125
+ # Gradle:
126
+ .idea/gradle.xml
127
+ .idea/libraries
128
+
129
+ # Mongo Explorer plugin:
130
+ .idea/mongoSettings.xml
131
+
132
+ .idea/
133
+
134
+ ## File-based project format:
135
+ *.iws
136
+
137
+ ## Plugin-specific files:
138
+
139
+ # IntelliJ
140
+ /out/
141
+
142
+ # mpeltonen/sbt-idea plugin
143
+ .idea_modules/
144
+
145
+ # JIRA plugin
146
+ atlassian-ide-plugin.xml
147
+
148
+ # Crashlytics plugin (for Android Studio and IntelliJ)
149
+ com_crashlytics_export_strings.xml
150
+ crashlytics.properties
151
+ crashlytics-build.properties
152
+ fabric.properties
153
+
154
+ # Mac related
155
+ .DS_Store
156
+
157
+ checkpoints
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/pytorch:23.05-py3
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE 1
4
+ ENV PYTHONUNBUFFERED 1
5
+
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ make \
8
+ pkgconf \
9
+ xz-utils \
10
+ xorg-dev \
11
+ libgl1-mesa-dev \
12
+ libglu1-mesa-dev \
13
+ libxrandr-dev \
14
+ libxinerama-dev \
15
+ libxcursor-dev \
16
+ libxi-dev \
17
+ libxxf86vm-dev \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ RUN pip install --no-cache-dir --upgrade pip
21
+
22
+ COPY requirements.txt .
23
+ RUN pip install --no-cache-dir -r requirements.txt
24
+
25
+ WORKDIR /workspace
26
+
27
+ RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
28
+ ENTRYPOINT ["/entry.sh"]
DragGAN.gif ADDED

Git LFS Details

  • SHA256: 2eab11d4dd1f11c2efacfcde385899b0164e241a7823eb050ab2e021f337225a
  • Pointer size: 133 Bytes
  • Size of remote file: 21.6 MB
LICENSE.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for StyleGAN3
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ "Licensor" means any person or entity that distributes its Work.
12
+
13
+ "Software" means the original work of authorship made available under
14
+ this License.
15
+
16
+ "Work" means the Software and any additions to or derivative works of
17
+ the Software that are made available under this License.
18
+
19
+ The terms "reproduce," "reproduction," "derivative works," and
20
+ "distribution" have the meaning as provided under U.S. copyright law;
21
+ provided, however, that for the purposes of this License, derivative
22
+ works shall not include works that remain separable from, or merely
23
+ link (or bind by name) to the interfaces of, the Work.
24
+
25
+ Works, including the Software, are "made available" under this License
26
+ by including in or with the Work either (a) a copyright notice
27
+ referencing the applicability of this License to the Work, or (b) a
28
+ copy of this License.
29
+
30
+ 2. License Grants
31
+
32
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
33
+ License, each Licensor grants to you a perpetual, worldwide,
34
+ non-exclusive, royalty-free, copyright license to reproduce,
35
+ prepare derivative works of, publicly display, publicly perform,
36
+ sublicense and distribute its Work and any resulting derivative
37
+ works in any form.
38
+
39
+ 3. Limitations
40
+
41
+ 3.1 Redistribution. You may reproduce or distribute the Work only
42
+ if (a) you do so under this License, (b) you include a complete
43
+ copy of this License with your distribution, and (c) you retain
44
+ without modification any copyright, patent, trademark, or
45
+ attribution notices that are present in the Work.
46
+
47
+ 3.2 Derivative Works. You may specify that additional or different
48
+ terms apply to the use, reproduction, and distribution of your
49
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
50
+ provide that the use limitation in Section 3.3 applies to your
51
+ derivative works, and (b) you identify the specific derivative
52
+ works that are subject to Your Terms. Notwithstanding Your Terms,
53
+ this License (including the redistribution requirements in Section
54
+ 3.1) will continue to apply to the Work itself.
55
+
56
+ 3.3 Use Limitation. The Work and any derivative works thereof only
57
+ may be used or intended for use non-commercially. Notwithstanding
58
+ the foregoing, NVIDIA and its affiliates may use the Work and any
59
+ derivative works commercially. As used herein, "non-commercially"
60
+ means for research or evaluation purposes only.
61
+
62
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
+ against any Licensor (including any claim, cross-claim or
64
+ counterclaim in a lawsuit) to enforce any patents that you allege
65
+ are infringed by any Work, then your rights under this License from
66
+ such Licensor (including the grant in Section 2.1) will terminate
67
+ immediately.
68
+
69
+ 3.5 Trademarks. This License does not grant any rights to use any
70
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
71
+ as necessary to reproduce the notices described in this License.
72
+
73
+ 3.6 Termination. If you violate any term of this License, then your
74
+ rights under this License (including the grant in Section 2.1) will
75
+ terminate immediately.
76
+
77
+ 4. Disclaimer of Warranty.
78
+
79
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
+ THIS LICENSE.
84
+
85
+ 5. Limitation of Liability.
86
+
87
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
+ THE POSSIBILITY OF SUCH DAMAGES.
96
+
97
+ =======================================================================
README.md CHANGED
@@ -1,12 +1,144 @@
1
  ---
2
- title: ShuangzizuoDragGAN
3
- emoji: 🌖
4
- colorFrom: purple
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: shuangzizuoDragGAN
3
+ app_file: visualizer_drag_gradio.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.35.2
 
 
6
  ---
7
+ <p align="center">
8
 
9
+ <h1 align="center">Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold</h1>
10
+ <p align="center">
11
+ <a href="https://xingangpan.github.io/"><strong>Xingang Pan</strong></a>
12
+ ·
13
+ <a href="https://ayushtewari.com/"><strong>Ayush Tewari</strong></a>
14
+ ·
15
+ <a href="https://people.mpi-inf.mpg.de/~tleimkue/"><strong>Thomas Leimkühler</strong></a>
16
+ ·
17
+ <a href="https://lingjie0206.github.io/"><strong>Lingjie Liu</strong></a>
18
+ ·
19
+ <a href="https://www.meka.page/"><strong>Abhimitra Meka</strong></a>
20
+ ·
21
+ <a href="http://www.mpi-inf.mpg.de/~theobalt/"><strong>Christian Theobalt</strong></a>
22
+ </p>
23
+ <h2 align="center">SIGGRAPH 2023 Conference Proceedings</h2>
24
+ <div align="center">
25
+ <img src="DragGAN.gif", width="600">
26
+ </div>
27
+
28
+ <p align="center">
29
+ <br>
30
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
31
+ <a href="https://twitter.com/XingangP"><img alt='Twitter' src="https://img.shields.io/twitter/follow/XingangP?label=%40XingangP"></a>
32
+ <a href="https://arxiv.org/abs/2305.10973">
33
+ <img src='https://img.shields.io/badge/Paper-PDF-green?style=for-the-badge&logo=adobeacrobatreader&logoWidth=20&logoColor=white&labelColor=66cc00&color=94DD15' alt='Paper PDF'>
34
+ </a>
35
+ <a href='https://vcai.mpi-inf.mpg.de/projects/DragGAN/'>
36
+ <img src='https://img.shields.io/badge/DragGAN-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=white&labelColor=D35400' alt='Project Page'></a>
37
+ <a href="https://colab.research.google.com/drive/1mey-IXPwQC_qSthI5hO-LTX7QL4ivtPh?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
38
+ </p>
39
+ </p>
40
+
41
+ ## Web Demos
42
+
43
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/XingangPan/DragGAN)
44
+
45
+ <p align="left">
46
+ <a href="https://huggingface.co/spaces/radames/DragGan"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DragGAN-orange"></a>
47
+ </p>
48
+
49
+ ## Requirements
50
+
51
+ If you have CUDA graphic card, please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements).
52
+
53
+ The usual installation steps involve the following commands, they should set up the correct CUDA version and all the python packages
54
+
55
+ ```
56
+ conda env create -f environment.yml
57
+ conda activate stylegan3
58
+ ```
59
+
60
+ Then install the additional requirements
61
+
62
+ ```
63
+ pip install -r requirements.txt
64
+ ```
65
+
66
+ Otherwise (for GPU acceleration on MacOS with Silicon Mac M1/M2, or just CPU) try the following:
67
+
68
+ ```sh
69
+ cat environment.yml | \
70
+ grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
71
+ conda env create -f environment-no-nvidia.yml
72
+ conda activate stylegan3
73
+
74
+ # On MacOS
75
+ export PYTORCH_ENABLE_MPS_FALLBACK=1
76
+ ```
77
+
78
+ ## Run Gradio visualizer in Docker
79
+
80
+ Provided docker image is based on NGC PyTorch repository. To quickly try out visualizer in Docker, run the following:
81
+
82
+ ```sh
83
+ # before you build the docker container, make sure you have cloned this repo, and downloaded the pretrained model by `python scripts/download_model.py`.
84
+ docker build . -t draggan:latest
85
+ docker run -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
86
+ # (Use GPU)if you want to utilize your Nvidia gpu to accelerate in docker, please add command tag `--gpus all`, like:
87
+ # docker run --gpus all -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
88
+
89
+ cd src && python visualizer_drag_gradio.py --listen
90
+ ```
91
+ Now you can open a shared link from Gradio (printed in the terminal console).
92
+ Beware the Docker image takes about 25GB of disk space!
93
+
94
+ ## Download pre-trained StyleGAN2 weights
95
+
96
+ To download pre-trained weights, simply run:
97
+
98
+ ```
99
+ python scripts/download_model.py
100
+ ```
101
+ If you want to try StyleGAN-Human and the Landscapes HQ (LHQ) dataset, please download weights from these links: [StyleGAN-Human](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing), [LHQ](https://drive.google.com/file/d/16twEf0T9QINAEoMsWefoWiyhcTd-aiWc/view?usp=sharing), and put them under `./checkpoints`.
102
+
103
+ Feel free to try other pretrained StyleGAN.
104
+
105
+ ## Run DragGAN GUI
106
+
107
+ To start the DragGAN GUI, simply run:
108
+ ```sh
109
+ sh scripts/gui.sh
110
+ ```
111
+ If you are using windows, you can run:
112
+ ```
113
+ .\scripts\gui.bat
114
+ ```
115
+
116
+ This GUI supports editing GAN-generated images. To edit a real image, you need to first perform GAN inversion using tools like [PTI](https://github.com/danielroich/PTI). Then load the new latent code and model weights to the GUI.
117
+
118
+ You can run DragGAN Gradio demo as well, this is universal for both windows and linux:
119
+ ```sh
120
+ python visualizer_drag_gradio.py
121
+ ```
122
+
123
+ ## Acknowledgement
124
+
125
+ This code is developed based on [StyleGAN3](https://github.com/NVlabs/stylegan3). Part of the code is borrowed from [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human).
126
+
127
+ (cheers to the community as well)
128
+ ## License
129
+
130
+ The code related to the DragGAN algorithm is licensed under [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/).
131
+ However, most of this project are available under a separate license terms: all codes used or modified from [StyleGAN3](https://github.com/NVlabs/stylegan3) is under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
132
+
133
+ Any form of use and derivative of this code must preserve the watermarking functionality showing "AI Generated".
134
+
135
+ ## BibTeX
136
+
137
+ ```bibtex
138
+ @inproceedings{pan2023draggan,
139
+ title={Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold},
140
+ author={Pan, Xingang and Tewari, Ayush, and Leimk{\"u}hler, Thomas and Liu, Lingjie and Meka, Abhimitra and Theobalt, Christian},
141
+ booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
142
+ year={2023}
143
+ }
144
+ ```
__pycache__/legacy.cpython-39.pyc ADDED
Binary file (15 kB). View file
 
arial.ttf ADDED
Binary file (276 kB). View file
 
checkpoints/stylegan2-afhqcat-512x512.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17a83abee464242f8bb40dc6363d33c2fb087066b68fc0147677fdbf21f7a7a9
3
+ size 363939583
checkpoints/stylegan2-car-config-f.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1618eee2b3ce87c4a3849442f7850ef12a478556bff035c8e09ee7e23b3794c
3
+ size 364027523
checkpoints/stylegan2-cat-config-f.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08940fd616cfaf6bc7b0286b5d1a0b3f70febb26e136d64716c8d3f5e9bd3883
3
+ size 357418027
checkpoints/stylegan2-ffhq-512x512.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2b1c92f41ce8a64c55f7a75ef06c4c0eef9e17b1eb29aae8c10fb37b3e60478
3
+ size 363939580
checkpoints/stylegan2-lhq-256x256.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae2c33d456ff3f274d472da9849c9bb515b7f0032af0d713d0d2d7f42b7942dc
3
+ size 357314829
checkpoints/stylegan2_dogs_1024_pytorch.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9e93090d02916165a602c728c0e37458fc0c58fbc58e4d75bcd096bb81c7e8c
3
+ size 381630441
checkpoints/stylegan2_elephants_512_pytorch.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56504894f6d7121959af78a74148cf1e9d858e3710312efb11c41dbf27684363
3
+ size 363965313
checkpoints/stylegan2_horses_256_pytorch.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f100dc32e2731a293f3b31a9038416f72aa5cc30555b3315a82e19c065f81b0c
3
+ size 357336721
checkpoints/stylegan2_lions_512_pytorch.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a01ff8344521171b1a2eff1e9a51c1acbc48221bdc2594919187f66a3942bcc
3
+ size 363965313
checkpoints/stylegan_human_v2_512.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7af33f638e3b3aba7bf99456eec3e9d4a022d8a7dd67683a3605a7dd37665a3b
3
+ size 352745981
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (187 Bytes). View file
 
dnnlib/__pycache__/util.cpython-39.pyc ADDED
Binary file (14 kB). View file
 
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
environment.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: stylegan3
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python >= 3.8
7
+ - pip
8
+ - numpy>=1.25
9
+ - click>=8.0
10
+ - pillow=9.4.0
11
+ #- scipy=1.11.0
12
+ - pytorch>=2.0.1
13
+ - torchvision>=0.15.2
14
+ #- cudatoolkit=11.1
15
+ - requests=2.26.0
16
+ - tqdm=4.62.2
17
+ - ninja=1.10.2
18
+ - matplotlib=3.4.2
19
+ - imageio=2.9.0
20
+ - pip:
21
+ - imgui==2.0.0
22
+ - glfw==2.6.1
23
+ - gradio==3.35.2
24
+ - pyopengl==3.1.5
25
+ - imageio-ffmpeg==0.4.3
26
+ # pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
27
+ - pyspng-seunglab
gen_images.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def parse_range(s: Union[str, List]) -> List[int]:
26
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
27
+
28
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
29
+ '''
30
+ if isinstance(s, list): return s
31
+ ranges = []
32
+ range_re = re.compile(r'^(\d+)-(\d+)$')
33
+ for p in s.split(','):
34
+ m = range_re.match(p)
35
+ if m:
36
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
37
+ else:
38
+ ranges.append(int(p))
39
+ return ranges
40
+
41
+ #----------------------------------------------------------------------------
42
+
43
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
44
+ '''Parse a floating point 2-vector of syntax 'a,b'.
45
+
46
+ Example:
47
+ '0,1' returns (0,1)
48
+ '''
49
+ if isinstance(s, tuple): return s
50
+ parts = s.split(',')
51
+ if len(parts) == 2:
52
+ return (float(parts[0]), float(parts[1]))
53
+ raise ValueError(f'cannot parse 2-vector {s}')
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def make_transform(translate: Tuple[float,float], angle: float):
58
+ m = np.eye(3)
59
+ s = np.sin(angle/360.0*np.pi*2)
60
+ c = np.cos(angle/360.0*np.pi*2)
61
+ m[0][0] = c
62
+ m[0][1] = s
63
+ m[0][2] = translate[0]
64
+ m[1][0] = -s
65
+ m[1][1] = c
66
+ m[1][2] = translate[1]
67
+ return m
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
+ @click.command()
72
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
73
+ @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
74
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
75
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
76
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
77
+ @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
78
+ @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
79
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
80
+ def generate_images(
81
+ network_pkl: str,
82
+ seeds: List[int],
83
+ truncation_psi: float,
84
+ noise_mode: str,
85
+ outdir: str,
86
+ translate: Tuple[float,float],
87
+ rotate: float,
88
+ class_idx: Optional[int]
89
+ ):
90
+ """Generate images using pretrained network pickle.
91
+
92
+ Examples:
93
+
94
+ \b
95
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
96
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
97
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
98
+
99
+ \b
100
+ # Generate uncurated images with truncation using the MetFaces-U dataset
101
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
102
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
103
+ """
104
+
105
+ print('Loading networks from "%s"...' % network_pkl)
106
+ device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
107
+ dtype = torch.float32 if device.type == 'mps' else torch.float64
108
+ with dnnlib.util.open_url(network_pkl) as f:
109
+ G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
110
+ # import pickle
111
+ # G = legacy.load_network_pkl(f)
112
+ # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
113
+ # pickle.dump(G, output)
114
+
115
+ os.makedirs(outdir, exist_ok=True)
116
+
117
+ # Labels.
118
+ label = torch.zeros([1, G.c_dim], device=device)
119
+ if G.c_dim != 0:
120
+ if class_idx is None:
121
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
122
+ label[:, class_idx] = 1
123
+ else:
124
+ if class_idx is not None:
125
+ print ('warn: --class=lbl ignored when running on an unconditional network')
126
+
127
+ # Generate images.
128
+ for seed_idx, seed in enumerate(seeds):
129
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
130
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
131
+
132
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
133
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
134
+ # operations in the network.
135
+ if hasattr(G.synthesis, 'input'):
136
+ m = make_transform(translate, rotate)
137
+ m = np.linalg.inv(m)
138
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
139
+
140
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
141
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
142
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
143
+
144
+
145
+ #----------------------------------------------------------------------------
146
+
147
+ if __name__ == "__main__":
148
+ generate_images() # pylint: disable=no-value-for-parameter
149
+
150
+ #----------------------------------------------------------------------------
gradio_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
2
+ get_latest_points_pair, get_valid_mask,
3
+ on_change_single_global_state)
4
+
5
+ __all__ = [
6
+ 'draw_mask_on_image', 'draw_points_on_image',
7
+ 'on_change_single_global_state', 'get_latest_points_pair',
8
+ 'get_valid_mask', 'ImageMask'
9
+ ]
gradio_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (374 Bytes). View file
 
gradio_utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.67 kB). View file
 
gradio_utils/utils.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+
5
+
6
+ class ImageMask(gr.components.Image):
7
+ """
8
+ Sets: source="canvas", tool="sketch"
9
+ """
10
+
11
+ is_template = True
12
+
13
+ def __init__(self, **kwargs):
14
+ super().__init__(source="upload",
15
+ tool="sketch",
16
+ interactive=False,
17
+ **kwargs)
18
+
19
+ def preprocess(self, x):
20
+ if x is None:
21
+ return x
22
+ if self.tool == "sketch" and self.source in ["upload", "webcam"
23
+ ] and type(x) != dict:
24
+ decode_image = gr.processing_utils.decode_base64_to_image(x)
25
+ width, height = decode_image.size
26
+ mask = np.ones((height, width, 4), dtype=np.uint8)
27
+ mask[..., -1] = 255
28
+ mask = self.postprocess(mask)
29
+ x = {'image': x, 'mask': mask}
30
+ return super().preprocess(x)
31
+
32
+
33
+ def get_valid_mask(mask: np.ndarray):
34
+ """Convert mask from gr.Image(0 to 255, RGBA) to binary mask.
35
+ """
36
+ if mask.ndim == 3:
37
+ mask_pil = Image.fromarray(mask).convert('L')
38
+ mask = np.array(mask_pil)
39
+ if mask.max() == 255:
40
+ mask = mask / 255
41
+ return mask
42
+
43
+
44
+ def draw_points_on_image(image,
45
+ points,
46
+ curr_point=None,
47
+ highlight_all=True,
48
+ radius_scale=0.01):
49
+ overlay_rgba = Image.new("RGBA", image.size, 0)
50
+ overlay_draw = ImageDraw.Draw(overlay_rgba)
51
+ for point_key, point in points.items():
52
+ if ((curr_point is not None and curr_point == point_key)
53
+ or highlight_all):
54
+ p_color = (255, 0, 0)
55
+ t_color = (0, 0, 255)
56
+
57
+ else:
58
+ p_color = (255, 0, 0, 35)
59
+ t_color = (0, 0, 255, 35)
60
+
61
+ rad_draw = int(image.size[0] * radius_scale)
62
+
63
+ p_start = point.get("start_temp", point["start"])
64
+ p_target = point["target"]
65
+
66
+ if p_start is not None and p_target is not None:
67
+ p_draw = int(p_start[0]), int(p_start[1])
68
+ t_draw = int(p_target[0]), int(p_target[1])
69
+
70
+ overlay_draw.line(
71
+ (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
72
+ fill=(255, 255, 0),
73
+ width=2,
74
+ )
75
+
76
+ if p_start is not None:
77
+ p_draw = int(p_start[0]), int(p_start[1])
78
+ overlay_draw.ellipse(
79
+ (
80
+ p_draw[0] - rad_draw,
81
+ p_draw[1] - rad_draw,
82
+ p_draw[0] + rad_draw,
83
+ p_draw[1] + rad_draw,
84
+ ),
85
+ fill=p_color,
86
+ )
87
+
88
+ if curr_point is not None and curr_point == point_key:
89
+ # overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0))
90
+ overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))
91
+
92
+ if p_target is not None:
93
+ t_draw = int(p_target[0]), int(p_target[1])
94
+ overlay_draw.ellipse(
95
+ (
96
+ t_draw[0] - rad_draw,
97
+ t_draw[1] - rad_draw,
98
+ t_draw[0] + rad_draw,
99
+ t_draw[1] + rad_draw,
100
+ ),
101
+ fill=t_color,
102
+ )
103
+
104
+ if curr_point is not None and curr_point == point_key:
105
+ # overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0))
106
+ overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))
107
+
108
+ return Image.alpha_composite(image.convert("RGBA"),
109
+ overlay_rgba).convert("RGB")
110
+
111
+
112
+ def draw_mask_on_image(image, mask):
113
+ im_mask = np.uint8(mask * 255)
114
+ im_mask_rgba = np.concatenate(
115
+ (
116
+ np.tile(im_mask[..., None], [1, 1, 3]),
117
+ 45 * np.ones(
118
+ (im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8),
119
+ ),
120
+ axis=-1,
121
+ )
122
+ im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
123
+
124
+ return Image.alpha_composite(image.convert("RGBA"),
125
+ im_mask_rgba).convert("RGB")
126
+
127
+
128
+ def on_change_single_global_state(keys,
129
+ value,
130
+ global_state,
131
+ map_transform=None):
132
+ if map_transform is not None:
133
+ value = map_transform(value)
134
+
135
+ curr_state = global_state
136
+ if isinstance(keys, str):
137
+ last_key = keys
138
+
139
+ else:
140
+ for k in keys[:-1]:
141
+ curr_state = curr_state[k]
142
+
143
+ last_key = keys[-1]
144
+
145
+ curr_state[last_key] = value
146
+ return global_state
147
+
148
+
149
+ def get_latest_points_pair(points_dict):
150
+ if not points_dict:
151
+ return None
152
+ point_idx = list(points_dict.keys())
153
+ latest_point_idx = max(point_idx)
154
+ return latest_point_idx
gui_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
gui_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (123 Bytes). View file
 
gui_utils/__pycache__/gl_utils.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
gui_utils/__pycache__/glfw_window.cpython-39.pyc ADDED
Binary file (7.72 kB). View file
 
gui_utils/__pycache__/imgui_utils.cpython-39.pyc ADDED
Binary file (5.78 kB). View file
 
gui_utils/__pycache__/imgui_window.cpython-39.pyc ADDED
Binary file (3.93 kB). View file
 
gui_utils/__pycache__/text_utils.cpython-39.pyc ADDED
Binary file (5.09 kB). View file
 
gui_utils/gl_utils.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ import os
11
+ import functools
12
+ import contextlib
13
+ import numpy as np
14
+ import OpenGL.GL as gl
15
+ import OpenGL.GL.ARB.texture_float
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def init_egl():
21
+ assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
22
+ import OpenGL.EGL as egl
23
+ import ctypes
24
+
25
+ # Initialize EGL.
26
+ display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
27
+ assert display != egl.EGL_NO_DISPLAY
28
+ major = ctypes.c_int32()
29
+ minor = ctypes.c_int32()
30
+ ok = egl.eglInitialize(display, major, minor)
31
+ assert ok
32
+ assert major.value * 10 + minor.value >= 14
33
+
34
+ # Choose config.
35
+ config_attribs = [
36
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
37
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
38
+ egl.EGL_NONE
39
+ ]
40
+ configs = (ctypes.c_int32 * 1)()
41
+ num_configs = ctypes.c_int32()
42
+ ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
43
+ assert ok
44
+ assert num_configs.value == 1
45
+ config = configs[0]
46
+
47
+ # Create dummy pbuffer surface.
48
+ surface_attribs = [
49
+ egl.EGL_WIDTH, 1,
50
+ egl.EGL_HEIGHT, 1,
51
+ egl.EGL_NONE
52
+ ]
53
+ surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
54
+ assert surface != egl.EGL_NO_SURFACE
55
+
56
+ # Setup GL context.
57
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
58
+ assert ok
59
+ context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
60
+ assert context != egl.EGL_NO_CONTEXT
61
+ ok = egl.eglMakeCurrent(display, surface, surface, context)
62
+ assert ok
63
+
64
+ #----------------------------------------------------------------------------
65
+
66
+ _texture_formats = {
67
+ ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
68
+ ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
69
+ ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
70
+ ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
71
+ ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
72
+ ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
73
+ ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
74
+ ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
75
+ }
76
+
77
+ def get_texture_format(dtype, channels):
78
+ return _texture_formats[(np.dtype(dtype).name, int(channels))]
79
+
80
+ #----------------------------------------------------------------------------
81
+
82
+ def prepare_texture_data(image):
83
+ image = np.asarray(image)
84
+ if image.ndim == 2:
85
+ image = image[:, :, np.newaxis]
86
+ if image.dtype.name == 'float64':
87
+ image = image.astype('float32')
88
+ return image
89
+
90
+ #----------------------------------------------------------------------------
91
+
92
+ def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
93
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
94
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
95
+ align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
96
+ image = prepare_texture_data(image)
97
+ height, width, channels = image.shape
98
+ size = zoom * [width, height]
99
+ pos = pos - size * align
100
+ if rint:
101
+ pos = np.rint(pos)
102
+ fmt = get_texture_format(image.dtype, channels)
103
+
104
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
105
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
106
+ gl.glRasterPos2f(pos[0], pos[1])
107
+ gl.glPixelZoom(zoom[0], -zoom[1])
108
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
109
+ gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
110
+ gl.glPopClientAttrib()
111
+ gl.glPopAttrib()
112
+
113
+ #----------------------------------------------------------------------------
114
+
115
+ def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
116
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
117
+ dtype = np.dtype(dtype)
118
+ fmt = get_texture_format(dtype, channels)
119
+ image = np.empty([height, width, channels], dtype=dtype)
120
+
121
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
122
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
123
+ gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
124
+ gl.glPopClientAttrib()
125
+ return np.flipud(image)
126
+
127
+ #----------------------------------------------------------------------------
128
+
129
+ class Texture:
130
+ def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
131
+ self.gl_id = None
132
+ self.bilinear = bilinear
133
+ self.mipmap = mipmap
134
+
135
+ # Determine size and dtype.
136
+ if image is not None:
137
+ image = prepare_texture_data(image)
138
+ self.height, self.width, self.channels = image.shape
139
+ self.dtype = image.dtype
140
+ else:
141
+ assert width is not None and height is not None
142
+ self.width = width
143
+ self.height = height
144
+ self.channels = channels if channels is not None else 3
145
+ self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
146
+
147
+ # Validate size and dtype.
148
+ assert isinstance(self.width, int) and self.width >= 0
149
+ assert isinstance(self.height, int) and self.height >= 0
150
+ assert isinstance(self.channels, int) and self.channels >= 1
151
+ assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
152
+
153
+ # Create texture object.
154
+ self.gl_id = gl.glGenTextures(1)
155
+ with self.bind():
156
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
157
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
158
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
159
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
160
+ self.update(image)
161
+
162
+ def delete(self):
163
+ if self.gl_id is not None:
164
+ gl.glDeleteTextures([self.gl_id])
165
+ self.gl_id = None
166
+
167
+ def __del__(self):
168
+ try:
169
+ self.delete()
170
+ except:
171
+ pass
172
+
173
+ @contextlib.contextmanager
174
+ def bind(self):
175
+ prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
176
+ gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
177
+ yield
178
+ gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
179
+
180
+ def update(self, image):
181
+ if image is not None:
182
+ image = prepare_texture_data(image)
183
+ assert self.is_compatible(image=image)
184
+ with self.bind():
185
+ fmt = get_texture_format(self.dtype, self.channels)
186
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
187
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
188
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
189
+ if self.mipmap:
190
+ gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
191
+ gl.glPopClientAttrib()
192
+
193
+ def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
194
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
195
+ size = zoom * [self.width, self.height]
196
+ with self.bind():
197
+ gl.glPushAttrib(gl.GL_ENABLE_BIT)
198
+ gl.glEnable(gl.GL_TEXTURE_2D)
199
+ draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
200
+ gl.glPopAttrib()
201
+
202
+ def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
203
+ if image is not None:
204
+ if image.ndim != 3:
205
+ return False
206
+ ih, iw, ic = image.shape
207
+ if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
208
+ return False
209
+ if width is not None and self.width != width:
210
+ return False
211
+ if height is not None and self.height != height:
212
+ return False
213
+ if channels is not None and self.channels != channels:
214
+ return False
215
+ if dtype is not None and self.dtype != dtype:
216
+ return False
217
+ return True
218
+
219
+ #----------------------------------------------------------------------------
220
+
221
+ class Framebuffer:
222
+ def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
223
+ self.texture = texture
224
+ self.gl_id = None
225
+ self.gl_color = None
226
+ self.gl_depth_stencil = None
227
+ self.msaa = msaa
228
+
229
+ # Determine size and dtype.
230
+ if texture is not None:
231
+ assert isinstance(self.texture, Texture)
232
+ self.width = texture.width
233
+ self.height = texture.height
234
+ self.channels = texture.channels
235
+ self.dtype = texture.dtype
236
+ else:
237
+ assert width is not None and height is not None
238
+ self.width = width
239
+ self.height = height
240
+ self.channels = channels if channels is not None else 4
241
+ self.dtype = np.dtype(dtype) if dtype is not None else np.float32
242
+
243
+ # Validate size and dtype.
244
+ assert isinstance(self.width, int) and self.width >= 0
245
+ assert isinstance(self.height, int) and self.height >= 0
246
+ assert isinstance(self.channels, int) and self.channels >= 1
247
+ assert width is None or width == self.width
248
+ assert height is None or height == self.height
249
+ assert channels is None or channels == self.channels
250
+ assert dtype is None or dtype == self.dtype
251
+
252
+ # Create framebuffer object.
253
+ self.gl_id = gl.glGenFramebuffers(1)
254
+ with self.bind():
255
+
256
+ # Setup color buffer.
257
+ if self.texture is not None:
258
+ assert self.msaa == 0
259
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
260
+ else:
261
+ fmt = get_texture_format(self.dtype, self.channels)
262
+ self.gl_color = gl.glGenRenderbuffers(1)
263
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
264
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
265
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
266
+
267
+ # Setup depth/stencil buffer.
268
+ self.gl_depth_stencil = gl.glGenRenderbuffers(1)
269
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
270
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
271
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
272
+
273
+ def delete(self):
274
+ if self.gl_id is not None:
275
+ gl.glDeleteFramebuffers([self.gl_id])
276
+ self.gl_id = None
277
+ if self.gl_color is not None:
278
+ gl.glDeleteRenderbuffers(1, [self.gl_color])
279
+ self.gl_color = None
280
+ if self.gl_depth_stencil is not None:
281
+ gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
282
+ self.gl_depth_stencil = None
283
+
284
+ def __del__(self):
285
+ try:
286
+ self.delete()
287
+ except:
288
+ pass
289
+
290
+ @contextlib.contextmanager
291
+ def bind(self):
292
+ prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
293
+ prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
294
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
295
+ if self.width is not None and self.height is not None:
296
+ gl.glViewport(0, 0, self.width, self.height)
297
+ yield
298
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
299
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
300
+
301
+ def blit(self, dst=None):
302
+ assert dst is None or isinstance(dst, Framebuffer)
303
+ with self.bind():
304
+ gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo)
305
+ gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
306
+
307
+ #----------------------------------------------------------------------------
308
+
309
+ def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
310
+ assert vertices.ndim == 2 and vertices.shape[1] == 2
311
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
312
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
313
+ color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
314
+ alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
315
+
316
+ gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
317
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
318
+ gl.glMatrixMode(gl.GL_MODELVIEW)
319
+ gl.glPushMatrix()
320
+
321
+ gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
322
+ gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
323
+ gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
324
+ gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
325
+ gl.glTranslate(pos[0], pos[1], 0)
326
+ gl.glScale(size[0], size[1], 1)
327
+ gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
328
+ gl.glDrawArrays(mode, 0, vertices.shape[0])
329
+
330
+ gl.glPopMatrix()
331
+ gl.glPopAttrib()
332
+ gl.glPopClientAttrib()
333
+
334
+ #----------------------------------------------------------------------------
335
+
336
+ def draw_arrow(x1, y1, x2, y2, l=10, width=1.0):
337
+ # Compute the length and angle of the arrow
338
+ dx = x2 - x1
339
+ dy = y2 - y1
340
+ length = math.sqrt(dx**2 + dy**2)
341
+ if length < l:
342
+ return
343
+ angle = math.atan2(dy, dx)
344
+
345
+ # Save the current modelview matrix
346
+ gl.glPushMatrix()
347
+
348
+ # Translate and rotate the coordinate system
349
+ gl.glTranslatef(x1, y1, 0.0)
350
+ gl.glRotatef(angle * 180.0 / math.pi, 0.0, 0.0, 1.0)
351
+
352
+ # Set the line width
353
+ gl.glLineWidth(width)
354
+ # gl.glColor3f(0.75, 0.75, 0.75)
355
+
356
+ # Begin drawing lines
357
+ gl.glBegin(gl.GL_LINES)
358
+
359
+ # Draw the shaft of the arrow
360
+ gl.glVertex2f(0.0, 0.0)
361
+ gl.glVertex2f(length, 0.0)
362
+
363
+ # Draw the head of the arrow
364
+ gl.glVertex2f(length, 0.0)
365
+ gl.glVertex2f(length - 2 * l, l)
366
+ gl.glVertex2f(length, 0.0)
367
+ gl.glVertex2f(length - 2 * l, -l)
368
+
369
+ # End drawing lines
370
+ gl.glEnd()
371
+
372
+ # Restore the modelview matrix
373
+ gl.glPopMatrix()
374
+
375
+ #----------------------------------------------------------------------------
376
+
377
+ def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
378
+ assert pos2 is None or size is None
379
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
380
+ pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
381
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
382
+ size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
383
+ pos = pos - size * align
384
+ if rint:
385
+ pos = np.rint(pos)
386
+ rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
387
+ rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
388
+ if np.min(rounding) == 0:
389
+ rounding *= 0
390
+ vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
391
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
392
+
393
+ @functools.lru_cache(maxsize=10000)
394
+ def _setup_rect(rx, ry):
395
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
396
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
397
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
398
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
399
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
400
+ return v.astype('float32')
401
+
402
+ #----------------------------------------------------------------------------
403
+
404
+ def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
405
+ hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
406
+ vertices = _setup_circle(float(hole))
407
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
408
+
409
+ @functools.lru_cache(maxsize=10000)
410
+ def _setup_circle(hole):
411
+ t = np.linspace(0, np.pi * 2, 128)
412
+ s = np.sin(t); c = np.cos(t)
413
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
414
+ return v.astype('float32')
415
+
416
+ #----------------------------------------------------------------------------
gui_utils/glfw_window.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import time
10
+ import glfw
11
+ import OpenGL.GL as gl
12
+ from . import gl_utils
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ class GlfwWindow: # pylint: disable=too-many-public-methods
17
+ def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
18
+ self._glfw_window = None
19
+ self._drawing_frame = False
20
+ self._frame_start_time = None
21
+ self._frame_delta = 0
22
+ self._fps_limit = None
23
+ self._vsync = None
24
+ self._skip_frames = 0
25
+ self._deferred_show = deferred_show
26
+ self._close_on_esc = close_on_esc
27
+ self._esc_pressed = False
28
+ self._drag_and_drop_paths = None
29
+ self._capture_next_frame = False
30
+ self._captured_frame = None
31
+
32
+ # Create window.
33
+ glfw.init()
34
+ glfw.window_hint(glfw.VISIBLE, False)
35
+ self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
36
+ self._attach_glfw_callbacks()
37
+ self.make_context_current()
38
+
39
+ # Adjust window.
40
+ self.set_vsync(False)
41
+ self.set_window_size(window_width, window_height)
42
+ if not self._deferred_show:
43
+ glfw.show_window(self._glfw_window)
44
+
45
+ def close(self):
46
+ if self._drawing_frame:
47
+ self.end_frame()
48
+ if self._glfw_window is not None:
49
+ glfw.destroy_window(self._glfw_window)
50
+ self._glfw_window = None
51
+ #glfw.terminate() # Commented out to play it nice with other glfw clients.
52
+
53
+ def __del__(self):
54
+ try:
55
+ self.close()
56
+ except:
57
+ pass
58
+
59
+ @property
60
+ def window_width(self):
61
+ return self.content_width
62
+
63
+ @property
64
+ def window_height(self):
65
+ return self.content_height + self.title_bar_height
66
+
67
+ @property
68
+ def content_width(self):
69
+ width, _height = glfw.get_window_size(self._glfw_window)
70
+ return width
71
+
72
+ @property
73
+ def content_height(self):
74
+ _width, height = glfw.get_window_size(self._glfw_window)
75
+ return height
76
+
77
+ @property
78
+ def title_bar_height(self):
79
+ _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
80
+ return top
81
+
82
+ @property
83
+ def monitor_width(self):
84
+ _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
85
+ return width
86
+
87
+ @property
88
+ def monitor_height(self):
89
+ _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
90
+ return height
91
+
92
+ @property
93
+ def frame_delta(self):
94
+ return self._frame_delta
95
+
96
+ def set_title(self, title):
97
+ glfw.set_window_title(self._glfw_window, title)
98
+
99
+ def set_window_size(self, width, height):
100
+ width = min(width, self.monitor_width)
101
+ height = min(height, self.monitor_height)
102
+ glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
103
+ if width == self.monitor_width and height == self.monitor_height:
104
+ self.maximize()
105
+
106
+ def set_content_size(self, width, height):
107
+ self.set_window_size(width, height + self.title_bar_height)
108
+
109
+ def maximize(self):
110
+ glfw.maximize_window(self._glfw_window)
111
+
112
+ def set_position(self, x, y):
113
+ glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
114
+
115
+ def center(self):
116
+ self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
117
+
118
+ def set_vsync(self, vsync):
119
+ vsync = bool(vsync)
120
+ if vsync != self._vsync:
121
+ glfw.swap_interval(1 if vsync else 0)
122
+ self._vsync = vsync
123
+
124
+ def set_fps_limit(self, fps_limit):
125
+ self._fps_limit = int(fps_limit)
126
+
127
+ def should_close(self):
128
+ return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
129
+
130
+ def skip_frame(self):
131
+ self.skip_frames(1)
132
+
133
+ def skip_frames(self, num): # Do not update window for the next N frames.
134
+ self._skip_frames = max(self._skip_frames, int(num))
135
+
136
+ def is_skipping_frames(self):
137
+ return self._skip_frames > 0
138
+
139
+ def capture_next_frame(self):
140
+ self._capture_next_frame = True
141
+
142
+ def pop_captured_frame(self):
143
+ frame = self._captured_frame
144
+ self._captured_frame = None
145
+ return frame
146
+
147
+ def pop_drag_and_drop_paths(self):
148
+ paths = self._drag_and_drop_paths
149
+ self._drag_and_drop_paths = None
150
+ return paths
151
+
152
+ def draw_frame(self): # To be overridden by subclass.
153
+ self.begin_frame()
154
+ # Rendering code goes here.
155
+ self.end_frame()
156
+
157
+ def make_context_current(self):
158
+ if self._glfw_window is not None:
159
+ glfw.make_context_current(self._glfw_window)
160
+
161
+ def begin_frame(self):
162
+ # End previous frame.
163
+ if self._drawing_frame:
164
+ self.end_frame()
165
+
166
+ # Apply FPS limit.
167
+ if self._frame_start_time is not None and self._fps_limit is not None:
168
+ delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
169
+ if delay > 0:
170
+ time.sleep(delay)
171
+ cur_time = time.perf_counter()
172
+ if self._frame_start_time is not None:
173
+ self._frame_delta = cur_time - self._frame_start_time
174
+ self._frame_start_time = cur_time
175
+
176
+ # Process events.
177
+ glfw.poll_events()
178
+
179
+ # Begin frame.
180
+ self._drawing_frame = True
181
+ self.make_context_current()
182
+
183
+ # Initialize GL state.
184
+ gl.glViewport(0, 0, self.content_width, self.content_height)
185
+ gl.glMatrixMode(gl.GL_PROJECTION)
186
+ gl.glLoadIdentity()
187
+ gl.glTranslate(-1, 1, 0)
188
+ gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
189
+ gl.glMatrixMode(gl.GL_MODELVIEW)
190
+ gl.glLoadIdentity()
191
+ gl.glEnable(gl.GL_BLEND)
192
+ gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
193
+
194
+ # Clear.
195
+ gl.glClearColor(0, 0, 0, 1)
196
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
197
+
198
+ def end_frame(self):
199
+ assert self._drawing_frame
200
+ self._drawing_frame = False
201
+
202
+ # Skip frames if requested.
203
+ if self._skip_frames > 0:
204
+ self._skip_frames -= 1
205
+ return
206
+
207
+ # Capture frame if requested.
208
+ if self._capture_next_frame:
209
+ self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
210
+ self._capture_next_frame = False
211
+
212
+ # Update window.
213
+ if self._deferred_show:
214
+ glfw.show_window(self._glfw_window)
215
+ self._deferred_show = False
216
+ glfw.swap_buffers(self._glfw_window)
217
+
218
+ def _attach_glfw_callbacks(self):
219
+ glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
220
+ glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
221
+
222
+ def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
223
+ if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
224
+ self._esc_pressed = True
225
+
226
+ def _glfw_drop_callback(self, _window, paths):
227
+ self._drag_and_drop_paths = paths
228
+
229
+ #----------------------------------------------------------------------------
gui_utils/imgui_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import contextlib
10
+ import imgui
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
15
+ s = imgui.get_style()
16
+ s.window_padding = [spacing, spacing]
17
+ s.item_spacing = [spacing, spacing]
18
+ s.item_inner_spacing = [spacing, spacing]
19
+ s.columns_min_spacing = spacing
20
+ s.indent_spacing = indent
21
+ s.scrollbar_size = scrollbar
22
+ s.frame_padding = [4, 3]
23
+ s.window_border_size = 1
24
+ s.child_border_size = 1
25
+ s.popup_border_size = 1
26
+ s.frame_border_size = 1
27
+ s.window_rounding = 0
28
+ s.child_rounding = 0
29
+ s.popup_rounding = 3
30
+ s.frame_rounding = 3
31
+ s.scrollbar_rounding = 3
32
+ s.grab_rounding = 3
33
+
34
+ getattr(imgui, f'style_colors_{color_scheme}')(s)
35
+ c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
36
+ c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
37
+ s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ @contextlib.contextmanager
42
+ def grayed_out(cond=True):
43
+ if cond:
44
+ s = imgui.get_style()
45
+ text = s.colors[imgui.COLOR_TEXT_DISABLED]
46
+ grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
47
+ back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
48
+ imgui.push_style_color(imgui.COLOR_TEXT, *text)
49
+ imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
50
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
51
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
52
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
53
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
54
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
55
+ imgui.push_style_color(imgui.COLOR_BUTTON, *back)
56
+ imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
57
+ imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
58
+ imgui.push_style_color(imgui.COLOR_HEADER, *back)
59
+ imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
60
+ imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
61
+ imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
62
+ yield
63
+ imgui.pop_style_color(14)
64
+ else:
65
+ yield
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
+ @contextlib.contextmanager
70
+ def item_width(width=None):
71
+ if width is not None:
72
+ imgui.push_item_width(width)
73
+ yield
74
+ imgui.pop_item_width()
75
+ else:
76
+ yield
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ def scoped_by_object_id(method):
81
+ def decorator(self, *args, **kwargs):
82
+ imgui.push_id(str(id(self)))
83
+ res = method(self, *args, **kwargs)
84
+ imgui.pop_id()
85
+ return res
86
+ return decorator
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ def button(label, width=0, enabled=True):
91
+ with grayed_out(not enabled):
92
+ clicked = imgui.button(label, width=width)
93
+ clicked = clicked and enabled
94
+ return clicked
95
+
96
+ #----------------------------------------------------------------------------
97
+
98
+ def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
99
+ expanded = False
100
+ if show:
101
+ if default:
102
+ flags |= imgui.TREE_NODE_DEFAULT_OPEN
103
+ if not enabled:
104
+ flags |= imgui.TREE_NODE_LEAF
105
+ with grayed_out(not enabled):
106
+ expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
107
+ expanded = expanded and enabled
108
+ return expanded, visible
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
+ def popup_button(label, width=0, enabled=True):
113
+ if button(label, width, enabled):
114
+ imgui.open_popup(label)
115
+ opened = imgui.begin_popup(label)
116
+ return opened
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ def input_text(label, value, buffer_length, flags, width=None, help_text=''):
121
+ old_value = value
122
+ color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
123
+ if value == '':
124
+ color[-1] *= 0.5
125
+ with item_width(width):
126
+ imgui.push_style_color(imgui.COLOR_TEXT, *color)
127
+ value = value if value != '' else help_text
128
+ changed, value = imgui.input_text(label, value, buffer_length, flags)
129
+ value = value if value != help_text else ''
130
+ imgui.pop_style_color(1)
131
+ if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
132
+ changed = (value != old_value)
133
+ return changed, value
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ def drag_previous_control(enabled=True):
138
+ dragging = False
139
+ dx = 0
140
+ dy = 0
141
+ if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
142
+ if enabled:
143
+ dragging = True
144
+ dx, dy = imgui.get_mouse_drag_delta()
145
+ imgui.reset_mouse_drag_delta()
146
+ imgui.end_drag_drop_source()
147
+ return dragging, dx, dy
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
+ def drag_button(label, width=0, enabled=True):
152
+ clicked = button(label, width=width, enabled=enabled)
153
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
154
+ return clicked, dragging, dx, dy
155
+
156
+ #----------------------------------------------------------------------------
157
+
158
+ def drag_hidden_window(label, x, y, width, height, enabled=True):
159
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
160
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
161
+ imgui.set_next_window_position(x, y)
162
+ imgui.set_next_window_size(width, height)
163
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
164
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
165
+ imgui.end()
166
+ imgui.pop_style_color(2)
167
+ return dragging, dx, dy
168
+
169
+ #----------------------------------------------------------------------------
170
+
171
+ def click_hidden_window(label, x, y, width, height, img_w, img_h, enabled=True):
172
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
173
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
174
+ imgui.set_next_window_position(x, y)
175
+ imgui.set_next_window_size(width, height)
176
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
177
+ clicked, down = False, False
178
+ img_x, img_y = 0, 0
179
+ if imgui.is_mouse_down():
180
+ posx, posy = imgui.get_mouse_pos()
181
+ if posx >= x and posx < x + width and posy >= y and posy < y + height:
182
+ if imgui.is_mouse_clicked():
183
+ clicked = True
184
+ down = True
185
+ img_x = round((posx - x) / (width - 1) * (img_w - 1))
186
+ img_y = round((posy - y) / (height - 1) * (img_h - 1))
187
+ imgui.end()
188
+ imgui.pop_style_color(2)
189
+ return clicked, down, img_x, img_y
190
+
191
+ #----------------------------------------------------------------------------
gui_utils/imgui_window.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import imgui
11
+ import imgui.integrations.glfw
12
+
13
+ from . import glfw_window
14
+ from . import imgui_utils
15
+ from . import text_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ class ImguiWindow(glfw_window.GlfwWindow):
20
+ def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
21
+ if font is None:
22
+ font = text_utils.get_default_font()
23
+ font_sizes = {int(size) for size in font_sizes}
24
+ super().__init__(title=title, **glfw_kwargs)
25
+
26
+ # Init fields.
27
+ self._imgui_context = None
28
+ self._imgui_renderer = None
29
+ self._imgui_fonts = None
30
+ self._cur_font_size = max(font_sizes)
31
+
32
+ # Delete leftover imgui.ini to avoid unexpected behavior.
33
+ if os.path.isfile('imgui.ini'):
34
+ os.remove('imgui.ini')
35
+
36
+ # Init ImGui.
37
+ self._imgui_context = imgui.create_context()
38
+ self._imgui_renderer = _GlfwRenderer(self._glfw_window)
39
+ self._attach_glfw_callbacks()
40
+ imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
41
+ imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
42
+ self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
43
+ self._imgui_renderer.refresh_font_texture()
44
+
45
+ def close(self):
46
+ self.make_context_current()
47
+ self._imgui_fonts = None
48
+ if self._imgui_renderer is not None:
49
+ self._imgui_renderer.shutdown()
50
+ self._imgui_renderer = None
51
+ if self._imgui_context is not None:
52
+ #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
53
+ self._imgui_context = None
54
+ super().close()
55
+
56
+ def _glfw_key_callback(self, *args):
57
+ super()._glfw_key_callback(*args)
58
+ self._imgui_renderer.keyboard_callback(*args)
59
+
60
+ @property
61
+ def font_size(self):
62
+ return self._cur_font_size
63
+
64
+ @property
65
+ def spacing(self):
66
+ return round(self._cur_font_size * 0.4)
67
+
68
+ def set_font_size(self, target): # Applied on next frame.
69
+ self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
70
+
71
+ def begin_frame(self):
72
+ # Begin glfw frame.
73
+ super().begin_frame()
74
+
75
+ # Process imgui events.
76
+ self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
77
+ if self.content_width > 0 and self.content_height > 0:
78
+ self._imgui_renderer.process_inputs()
79
+
80
+ # Begin imgui frame.
81
+ imgui.new_frame()
82
+ imgui.push_font(self._imgui_fonts[self._cur_font_size])
83
+ imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
84
+
85
+ def end_frame(self):
86
+ imgui.pop_font()
87
+ imgui.render()
88
+ imgui.end_frame()
89
+ self._imgui_renderer.render(imgui.get_draw_data())
90
+ super().end_frame()
91
+
92
+ #----------------------------------------------------------------------------
93
+ # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
94
+
95
+ class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
96
+ def __init__(self, *args, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self.mouse_wheel_multiplier = 1
99
+
100
+ def scroll_callback(self, window, x_offset, y_offset):
101
+ self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
102
+
103
+ #----------------------------------------------------------------------------
gui_utils/text_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ from typing import Optional
11
+
12
+ import dnnlib
13
+ import numpy as np
14
+ import PIL.Image
15
+ import PIL.ImageFont
16
+ import scipy.ndimage
17
+
18
+ from . import gl_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def get_default_font():
23
+ url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
24
+ return dnnlib.util.open_url(url, return_filename=True)
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ @functools.lru_cache(maxsize=None)
29
+ def get_pil_font(font=None, size=32):
30
+ if font is None:
31
+ font = get_default_font()
32
+ return PIL.ImageFont.truetype(font=font, size=size)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def get_array(string, *, dropshadow_radius: int=None, **kwargs):
37
+ if dropshadow_radius is not None:
38
+ offset_x = int(np.ceil(dropshadow_radius*2/3))
39
+ offset_y = int(np.ceil(dropshadow_radius*2/3))
40
+ return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
41
+ else:
42
+ return _get_array_priv(string, **kwargs)
43
+
44
+ @functools.lru_cache(maxsize=10000)
45
+ def _get_array_priv(
46
+ string: str, *,
47
+ size: int = 32,
48
+ max_width: Optional[int]=None,
49
+ max_height: Optional[int]=None,
50
+ min_size=10,
51
+ shrink_coef=0.8,
52
+ dropshadow_radius: int=None,
53
+ offset_x: int=None,
54
+ offset_y: int=None,
55
+ **kwargs
56
+ ):
57
+ cur_size = size
58
+ array = None
59
+ while True:
60
+ if dropshadow_radius is not None:
61
+ # separate implementation for dropshadow text rendering
62
+ array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
63
+ else:
64
+ array = _get_array_impl(string, size=cur_size, **kwargs)
65
+ height, width, _ = array.shape
66
+ if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
67
+ break
68
+ cur_size = max(int(cur_size * shrink_coef), min_size)
69
+ return array
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ @functools.lru_cache(maxsize=10000)
74
+ def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
75
+ pil_font = get_pil_font(font=font, size=size)
76
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
77
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
78
+ width = max(line.shape[1] for line in lines)
79
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
80
+ line_spacing = line_pad if line_pad is not None else size // 2
81
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
82
+ mask = np.concatenate(lines, axis=0)
83
+ alpha = mask
84
+ if outline > 0:
85
+ mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
86
+ alpha = mask.astype(np.float32) / 255
87
+ alpha = scipy.ndimage.gaussian_filter(alpha, outline)
88
+ alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
89
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
90
+ alpha = np.maximum(alpha, mask)
91
+ return np.stack([mask, alpha], axis=-1)
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ @functools.lru_cache(maxsize=10000)
96
+ def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
97
+ assert (offset_x > 0) and (offset_y > 0)
98
+ pil_font = get_pil_font(font=font, size=size)
99
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
100
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
101
+ width = max(line.shape[1] for line in lines)
102
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
103
+ line_spacing = line_pad if line_pad is not None else size // 2
104
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
105
+ mask = np.concatenate(lines, axis=0)
106
+ alpha = mask
107
+
108
+ mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
109
+ alpha = mask.astype(np.float32) / 255
110
+ alpha = scipy.ndimage.gaussian_filter(alpha, radius)
111
+ alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
112
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
113
+ alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
114
+ alpha = np.maximum(alpha, mask)
115
+ return np.stack([mask, alpha], axis=-1)
116
+
117
+ #----------------------------------------------------------------------------
118
+
119
+ @functools.lru_cache(maxsize=10000)
120
+ def get_texture(string, bilinear=True, mipmap=True, **kwargs):
121
+ return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
122
+
123
+ #----------------------------------------------------------------------------
legacy.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from torch_utils import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def load_network_pkl(f, force_fp16=False):
23
+ data = _LegacyUnpickler(f).load()
24
+
25
+ # Legacy TensorFlow pickle => convert.
26
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
27
+ tf_G, tf_D, tf_Gs = data
28
+ G = convert_tf_generator(tf_G)
29
+ D = convert_tf_discriminator(tf_D)
30
+ G_ema = convert_tf_generator(tf_Gs)
31
+ data = dict(G=G, D=D, G_ema=G_ema)
32
+
33
+ # Add missing fields.
34
+ if 'training_set_kwargs' not in data:
35
+ data['training_set_kwargs'] = None
36
+ if 'augment_pipe' not in data:
37
+ data['augment_pipe'] = None
38
+
39
+ # Validate contents.
40
+ assert isinstance(data['G'], torch.nn.Module)
41
+ assert isinstance(data['D'], torch.nn.Module)
42
+ assert isinstance(data['G_ema'], torch.nn.Module)
43
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
44
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
45
+
46
+ # Force FP16.
47
+ if force_fp16:
48
+ for key in ['G', 'D', 'G_ema']:
49
+ old = data[key]
50
+ kwargs = copy.deepcopy(old.init_kwargs)
51
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
52
+ fp16_kwargs.num_fp16_res = 4
53
+ fp16_kwargs.conv_clamp = 256
54
+ if kwargs != old.init_kwargs:
55
+ new = type(old)(**kwargs).eval().requires_grad_(False)
56
+ misc.copy_params_and_buffers(old, new, require_all=True)
57
+ data[key] = new
58
+ return data
59
+
60
+ #----------------------------------------------------------------------------
61
+
62
+ class _TFNetworkStub(dnnlib.EasyDict):
63
+ pass
64
+
65
+ class _LegacyUnpickler(pickle.Unpickler):
66
+ def find_class(self, module, name):
67
+ if module == 'dnnlib.tflib.network' and name == 'Network':
68
+ return _TFNetworkStub
69
+ return super().find_class(module, name)
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ def _collect_tf_params(tf_net):
74
+ # pylint: disable=protected-access
75
+ tf_params = dict()
76
+ def recurse(prefix, tf_net):
77
+ for name, value in tf_net.variables:
78
+ tf_params[prefix + name] = value
79
+ for name, comp in tf_net.components.items():
80
+ recurse(prefix + name + '/', comp)
81
+ recurse('', tf_net)
82
+ return tf_params
83
+
84
+ #----------------------------------------------------------------------------
85
+
86
+ def _populate_module_params(module, *patterns):
87
+ for name, tensor in misc.named_params_and_buffers(module):
88
+ found = False
89
+ value = None
90
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
91
+ match = re.fullmatch(pattern, name)
92
+ if match:
93
+ found = True
94
+ if value_fn is not None:
95
+ value = value_fn(*match.groups())
96
+ break
97
+ try:
98
+ assert found
99
+ if value is not None:
100
+ tensor.copy_(torch.from_numpy(np.array(value)))
101
+ except:
102
+ print(name, list(tensor.shape))
103
+ raise
104
+
105
+ #----------------------------------------------------------------------------
106
+
107
+ def convert_tf_generator(tf_G):
108
+ if tf_G.version < 4:
109
+ raise ValueError('TensorFlow pickle version too low')
110
+
111
+ # Collect kwargs.
112
+ tf_kwargs = tf_G.static_kwargs
113
+ known_kwargs = set()
114
+ def kwarg(tf_name, default=None, none=None):
115
+ known_kwargs.add(tf_name)
116
+ val = tf_kwargs.get(tf_name, default)
117
+ return val if val is not None else none
118
+
119
+ # Convert kwargs.
120
+ from training import networks_stylegan2
121
+ network_class = networks_stylegan2.Generator
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ channel_base = kwarg('fmap_base', 16384) * 2,
129
+ channel_max = kwarg('fmap_max', 512),
130
+ num_fp16_res = kwarg('num_fp16_res', 0),
131
+ conv_clamp = kwarg('conv_clamp', None),
132
+ architecture = kwarg('architecture', 'skip'),
133
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
134
+ use_noise = kwarg('use_noise', True),
135
+ activation = kwarg('nonlinearity', 'lrelu'),
136
+ mapping_kwargs = dnnlib.EasyDict(
137
+ num_layers = kwarg('mapping_layers', 8),
138
+ embed_features = kwarg('label_fmaps', None),
139
+ layer_features = kwarg('mapping_fmaps', None),
140
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
141
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
142
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
143
+ ),
144
+ )
145
+
146
+ # Check for unknown kwargs.
147
+ kwarg('truncation_psi')
148
+ kwarg('truncation_cutoff')
149
+ kwarg('style_mixing_prob')
150
+ kwarg('structure')
151
+ kwarg('conditioning')
152
+ kwarg('fused_modconv')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ G = network_class(**kwargs).eval().requires_grad_(False)
169
+ # pylint: disable=unnecessary-lambda
170
+ # pylint: disable=f-string-without-interpolation
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ r'.*\.act_filter', None,
203
+ )
204
+ return G
205
+
206
+ #----------------------------------------------------------------------------
207
+
208
+ def convert_tf_discriminator(tf_D):
209
+ if tf_D.version < 4:
210
+ raise ValueError('TensorFlow pickle version too low')
211
+
212
+ # Collect kwargs.
213
+ tf_kwargs = tf_D.static_kwargs
214
+ known_kwargs = set()
215
+ def kwarg(tf_name, default=None):
216
+ known_kwargs.add(tf_name)
217
+ return tf_kwargs.get(tf_name, default)
218
+
219
+ # Convert kwargs.
220
+ kwargs = dnnlib.EasyDict(
221
+ c_dim = kwarg('label_size', 0),
222
+ img_resolution = kwarg('resolution', 1024),
223
+ img_channels = kwarg('num_channels', 3),
224
+ architecture = kwarg('architecture', 'resnet'),
225
+ channel_base = kwarg('fmap_base', 16384) * 2,
226
+ channel_max = kwarg('fmap_max', 512),
227
+ num_fp16_res = kwarg('num_fp16_res', 0),
228
+ conv_clamp = kwarg('conv_clamp', None),
229
+ cmap_dim = kwarg('mapping_fmaps', None),
230
+ block_kwargs = dnnlib.EasyDict(
231
+ activation = kwarg('nonlinearity', 'lrelu'),
232
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
233
+ freeze_layers = kwarg('freeze_layers', 0),
234
+ ),
235
+ mapping_kwargs = dnnlib.EasyDict(
236
+ num_layers = kwarg('mapping_layers', 0),
237
+ embed_features = kwarg('mapping_fmaps', None),
238
+ layer_features = kwarg('mapping_fmaps', None),
239
+ activation = kwarg('nonlinearity', 'lrelu'),
240
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
241
+ ),
242
+ epilogue_kwargs = dnnlib.EasyDict(
243
+ mbstd_group_size = kwarg('mbstd_group_size', None),
244
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
245
+ activation = kwarg('nonlinearity', 'lrelu'),
246
+ ),
247
+ )
248
+
249
+ # Check for unknown kwargs.
250
+ kwarg('structure')
251
+ kwarg('conditioning')
252
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
253
+ if len(unknown_kwargs) > 0:
254
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
255
+
256
+ # Collect params.
257
+ tf_params = _collect_tf_params(tf_D)
258
+ for name, value in list(tf_params.items()):
259
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
260
+ if match:
261
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
262
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
263
+ kwargs.architecture = 'orig'
264
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
265
+
266
+ # Convert params.
267
+ from training import networks_stylegan2
268
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
269
+ # pylint: disable=unnecessary-lambda
270
+ # pylint: disable=f-string-without-interpolation
271
+ _populate_module_params(D,
272
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
273
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
274
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
275
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
276
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
277
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
278
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
279
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
280
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
281
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
282
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
283
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
284
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
285
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
286
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
287
+ r'.*\.resample_filter', None,
288
+ )
289
+ return D
290
+
291
+ #----------------------------------------------------------------------------
292
+
293
+ @click.command()
294
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
295
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
296
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
297
+ def convert_network_pickle(source, dest, force_fp16):
298
+ """Convert legacy network pickle into the native PyTorch format.
299
+
300
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
301
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
302
+
303
+ Example:
304
+
305
+ \b
306
+ python legacy.py \\
307
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
308
+ --dest=stylegan2-cat-config-f.pkl
309
+ """
310
+ print(f'Loading "{source}"...')
311
+ with dnnlib.util.open_url(source) as f:
312
+ data = load_network_pkl(f, force_fp16=force_fp16)
313
+ print(f'Saving "{dest}"...')
314
+ with open(dest, 'wb') as f:
315
+ pickle.dump(data, f)
316
+ print('Done.')
317
+
318
+ #----------------------------------------------------------------------------
319
+
320
+ if __name__ == "__main__":
321
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
322
+
323
+ #----------------------------------------------------------------------------
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ scipy==1.11.0
3
+ Ninja==1.10.2
4
+ gradio>=3.35.2
5
+ imageio-ffmpeg>=0.4.3
6
+ huggingface_hub
7
+ hf_transfer
8
+ pyopengl
9
+ imgui
10
+ glfw==2.6.1
11
+ pillow>=9.4.0
12
+ torchvision>=0.15.2
13
+ imageio>=2.9.0
scripts/download_model.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+ def download_file(url: str, filename: str, download_dir: str):
8
+ """Download a file if it does not already exist."""
9
+
10
+ try:
11
+ filepath = os.path.join(download_dir, filename)
12
+ content_length = int(requests.head(url).headers.get("content-length", 0))
13
+
14
+ # If file already exists and size matches, skip download
15
+ if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length:
16
+ print(f"{filepath} already exists. Skipping download.")
17
+ return
18
+ if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length:
19
+ print(f"{filepath} already exists but size does not match. Redownloading.")
20
+ else:
21
+ print(f"Downloading {filename} from {url}")
22
+
23
+ # Start download, stream=True allows for progress tracking
24
+ response = requests.get(url, stream=True)
25
+
26
+ # Check if request was successful
27
+ response.raise_for_status()
28
+
29
+ # Create progress bar
30
+ total_size = int(response.headers.get('content-length', 0))
31
+ progress_bar = tqdm(
32
+ total=total_size,
33
+ unit='iB',
34
+ unit_scale=True,
35
+ ncols=70,
36
+ file=sys.stdout
37
+ )
38
+
39
+ # Write response content to file
40
+ with open(filepath, 'wb') as f:
41
+ for data in response.iter_content(chunk_size=1024):
42
+ f.write(data)
43
+ progress_bar.update(len(data)) # Update progress bar
44
+
45
+ # Close progress bar
46
+ progress_bar.close()
47
+
48
+ # Error handling for incomplete downloads
49
+ if total_size != 0 and progress_bar.n != total_size:
50
+ print("ERROR, something went wrong while downloading")
51
+ raise Exception()
52
+
53
+
54
+ except Exception as e:
55
+ print(f"An error occurred: {e}")
56
+
57
+ def main():
58
+ """Main function to download files from URLs in a config file."""
59
+
60
+ # Get JSON config file path
61
+ script_dir = os.path.dirname(os.path.realpath(__file__))
62
+ config_file_path = os.path.join(script_dir, "download_models.json")
63
+
64
+ # Set download directory
65
+ download_dir = "checkpoints"
66
+ os.makedirs(download_dir, exist_ok=True)
67
+
68
+ # Load URL and filenames from JSON
69
+ with open(config_file_path, "r") as f:
70
+ config = json.load(f)
71
+
72
+ # Download each file specified in config
73
+ for url, filename in config.items():
74
+ download_file(url, filename, download_dir)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
scripts/download_models.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl": "stylegan2_lions_512_pytorch.pkl",
3
+ "https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl": "stylegan2_dogs_1024_pytorch.pkl",
4
+ "https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl": "stylegan2_horses_256_pytorch.pkl",
5
+ "https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl": "stylegan2_elephants_512_pytorch.pkl",
6
+ "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl": "stylegan2-ffhq-512x512.pkl",
7
+ "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl": "stylegan2-afhqcat-512x512.pkl",
8
+ "http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl": "stylegan2-car-config-f.pkl",
9
+ "http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl": "stylegan2-cat-config-f.pkl"
10
+ }
scripts/gui.bat ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ python visualizer_drag.py ^
3
+ checkpoints/stylegan2_lions_512_pytorch.pkl ^
4
+ checkpoints/stylegan2-ffhq-512x512.pkl ^
5
+ checkpoints/stylegan2-afhqcat-512x512.pkl ^
6
+ checkpoints/stylegan2-car-config-f.pkl ^
7
+ checkpoints/stylegan2_dogs_1024_pytorch.pkl ^
8
+ checkpoints/stylegan2_horses_256_pytorch.pkl ^
9
+ checkpoints/stylegan2-cat-config-f.pkl ^
10
+ checkpoints/stylegan2_elephants_512_pytorch.pkl ^
11
+ checkpoints/stylegan_human_v2_512.pkl ^
12
+ checkpoints/stylegan2-lhq-256x256.pkl
scripts/gui.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python visualizer_drag.py \
2
+ checkpoints/stylegan2_lions_512_pytorch.pkl \
3
+ checkpoints/stylegan2-ffhq-512x512.pkl \
4
+ checkpoints/stylegan2-afhqcat-512x512.pkl \
5
+ checkpoints/stylegan2-car-config-f.pkl \
6
+ checkpoints/stylegan2_dogs_1024_pytorch.pkl \
7
+ checkpoints/stylegan2_horses_256_pytorch.pkl \
8
+ checkpoints/stylegan2-cat-config-f.pkl \
9
+ checkpoints/stylegan2_elephants_512_pytorch.pkl \
10
+ checkpoints/stylegan_human_v2_512.pkl \
11
+ checkpoints/stylegan2-lhq-256x256.pkl
stylegan_human/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ __pycache__
3
+ *.pt
4
+ *.pth
5
+ *.pdparams
6
+ *.pdiparams
7
+ *.pdmodel
8
+ *.pkl
9
+ *.info
10
+ *.yaml
stylegan_human/PP_HumanSeg/deploy/infer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+
4
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import codecs
19
+ import os
20
+ import time
21
+
22
+ import yaml
23
+ import numpy as np
24
+ import cv2
25
+ import paddle
26
+ import paddleseg.transforms as T
27
+ from paddle.inference import create_predictor, PrecisionType
28
+ from paddle.inference import Config as PredictConfig
29
+ from paddleseg.core.infer import reverse_transform
30
+ from paddleseg.cvlibs import manager
31
+ from paddleseg.utils import TimeAverager
32
+
33
+ from ..scripts.optic_flow_process import optic_flow_process
34
+
35
+
36
+ class DeployConfig:
37
+ def __init__(self, path):
38
+ with codecs.open(path, 'r', 'utf-8') as file:
39
+ self.dic = yaml.load(file, Loader=yaml.FullLoader)
40
+
41
+ self._transforms = self._load_transforms(self.dic['Deploy'][
42
+ 'transforms'])
43
+ self._dir = os.path.dirname(path)
44
+
45
+ @property
46
+ def transforms(self):
47
+ return self._transforms
48
+
49
+ @property
50
+ def model(self):
51
+ return os.path.join(self._dir, self.dic['Deploy']['model'])
52
+
53
+ @property
54
+ def params(self):
55
+ return os.path.join(self._dir, self.dic['Deploy']['params'])
56
+
57
+ def _load_transforms(self, t_list):
58
+ com = manager.TRANSFORMS
59
+ transforms = []
60
+ for t in t_list:
61
+ ctype = t.pop('type')
62
+ transforms.append(com[ctype](**t))
63
+
64
+ return transforms
65
+
66
+
67
+ class Predictor:
68
+ def __init__(self, args):
69
+ self.cfg = DeployConfig(args.cfg)
70
+ self.args = args
71
+ self.compose = T.Compose(self.cfg.transforms)
72
+ resize_h, resize_w = args.input_shape
73
+
74
+ self.disflow = cv2.DISOpticalFlow_create(
75
+ cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
76
+ self.prev_gray = np.zeros((resize_h, resize_w), np.uint8)
77
+ self.prev_cfd = np.zeros((resize_h, resize_w), np.float32)
78
+ self.is_init = True
79
+
80
+ pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
81
+ pred_cfg.disable_glog_info()
82
+ if self.args.use_gpu:
83
+ pred_cfg.enable_use_gpu(100, 0)
84
+
85
+ self.predictor = create_predictor(pred_cfg)
86
+ if self.args.test_speed:
87
+ self.cost_averager = TimeAverager()
88
+
89
+ def preprocess(self, img):
90
+ ori_shapes = []
91
+ processed_imgs = []
92
+ processed_img = self.compose(img)[0]
93
+ processed_imgs.append(processed_img)
94
+ ori_shapes.append(img.shape)
95
+ return processed_imgs, ori_shapes
96
+
97
+ def run(self, img, bg):
98
+ input_names = self.predictor.get_input_names()
99
+ input_handle = self.predictor.get_input_handle(input_names[0])
100
+ processed_imgs, ori_shapes = self.preprocess(img)
101
+ data = np.array(processed_imgs)
102
+ input_handle.reshape(data.shape)
103
+ input_handle.copy_from_cpu(data)
104
+ if self.args.test_speed:
105
+ start = time.time()
106
+
107
+ self.predictor.run()
108
+
109
+ if self.args.test_speed:
110
+ self.cost_averager.record(time.time() - start)
111
+ output_names = self.predictor.get_output_names()
112
+ output_handle = self.predictor.get_output_handle(output_names[0])
113
+ output = output_handle.copy_to_cpu()
114
+ return self.postprocess(output, img, ori_shapes[0], bg)
115
+
116
+
117
+ def postprocess(self, pred, img, ori_shape, bg):
118
+ if not os.path.exists(self.args.save_dir):
119
+ os.makedirs(self.args.save_dir)
120
+ resize_w = pred.shape[-1]
121
+ resize_h = pred.shape[-2]
122
+ if self.args.soft_predict:
123
+ if self.args.use_optic_flow:
124
+ score_map = pred[:, 1, :, :].squeeze(0)
125
+ score_map = 255 * score_map
126
+ cur_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
127
+ cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
128
+ optflow_map = optic_flow_process(cur_gray, score_map, self.prev_gray, self.prev_cfd, \
129
+ self.disflow, self.is_init)
130
+ self.prev_gray = cur_gray.copy()
131
+ self.prev_cfd = optflow_map.copy()
132
+ self.is_init = False
133
+
134
+ score_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
135
+ score_map = np.transpose(score_map, [2, 0, 1])[np.newaxis, ...]
136
+ score_map = reverse_transform(
137
+ paddle.to_tensor(score_map),
138
+ ori_shape,
139
+ self.cfg.transforms,
140
+ mode='bilinear')
141
+ alpha = np.transpose(score_map.numpy().squeeze(0),
142
+ [1, 2, 0]) / 255
143
+ else:
144
+ score_map = pred[:, 1, :, :]
145
+ score_map = score_map[np.newaxis, ...]
146
+ score_map = reverse_transform(
147
+ paddle.to_tensor(score_map),
148
+ ori_shape,
149
+ self.cfg.transforms,
150
+ mode='bilinear')
151
+ alpha = np.transpose(score_map.numpy().squeeze(0), [1, 2, 0])
152
+
153
+ else:
154
+ if pred.ndim == 3:
155
+ pred = pred[:, np.newaxis, ...]
156
+ result = reverse_transform(
157
+ paddle.to_tensor(
158
+ pred, dtype='float32'),
159
+ ori_shape,
160
+ self.cfg.transforms,
161
+ mode='bilinear')
162
+
163
+ result = np.array(result)
164
+ if self.args.add_argmax:
165
+ result = np.argmax(result, axis=1)
166
+ else:
167
+ result = result.squeeze(1)
168
+ alpha = np.transpose(result, [1, 2, 0])
169
+
170
+ # background replace
171
+ h, w, _ = img.shape
172
+ if bg is None:
173
+ bg = np.ones_like(img)*255
174
+ else:
175
+ bg = cv2.resize(bg, (w, h))
176
+ if bg.ndim == 2:
177
+ bg = bg[..., np.newaxis]
178
+
179
+ comb = (alpha * img + (1 - alpha) * bg).astype(np.uint8)
180
+ return comb, alpha, bg, img
stylegan_human/PP_HumanSeg/export_model/download_export_model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf8
2
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import sys
17
+ import os
18
+
19
+ LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
20
+ TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
21
+ sys.path.append(TEST_PATH)
22
+
23
+ from paddleseg.utils.download import download_file_and_uncompress
24
+
25
+ model_urls = {
26
+ "pphumanseg_lite_portrait_398x224_with_softmax":
27
+ "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz",
28
+ "deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax":
29
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip",
30
+ "fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax":
31
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax.zip",
32
+ "pphumanseg_lite_generic_humanseg_192x192_with_softmax":
33
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/pphumanseg_lite_generic_192x192_with_softmax.zip",
34
+ }
35
+
36
+ if __name__ == "__main__":
37
+ for model_name, url in model_urls.items():
38
+ download_file_and_uncompress(
39
+ url=url,
40
+ savepath=LOCAL_PATH,
41
+ extrapath=LOCAL_PATH,
42
+ extraname=model_name)
43
+
44
+ print("Export model download success!")
stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf8
2
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import sys
17
+ import os
18
+
19
+ LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
20
+ TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
21
+ sys.path.append(TEST_PATH)
22
+
23
+ from paddleseg.utils.download import download_file_and_uncompress
24
+
25
+ model_urls = {
26
+ "pphumanseg_lite_portrait_398x224":
27
+ "https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224.tar.gz",
28
+ "deeplabv3p_resnet50_os8_humanseg_512x512_100k":
29
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip",
30
+ "fcn_hrnetw18_small_v1_humanseg_192x192":
31
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/fcn_hrnetw18_small_v1_humanseg_192x192.zip",
32
+ "pphumanseg_lite_generic_human_192x192":
33
+ "https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/pphumanseg_lite_generic_192x192.zip",
34
+ }
35
+
36
+ if __name__ == "__main__":
37
+ for model_name, url in model_urls.items():
38
+ download_file_and_uncompress(
39
+ url=url,
40
+ savepath=LOCAL_PATH,
41
+ extrapath=LOCAL_PATH,
42
+ extraname=model_name)
43
+
44
+ print("Pretrained model download success!")