Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- .gitignore +157 -0
- Dockerfile +28 -0
- DragGAN.gif +3 -0
- LICENSE.txt +97 -0
- README.md +140 -8
- __pycache__/legacy.cpython-39.pyc +0 -0
- arial.ttf +0 -0
- checkpoints/stylegan2-afhqcat-512x512.pkl +3 -0
- checkpoints/stylegan2-car-config-f.pkl +3 -0
- checkpoints/stylegan2-cat-config-f.pkl +3 -0
- checkpoints/stylegan2-ffhq-512x512.pkl +3 -0
- checkpoints/stylegan2-lhq-256x256.pkl +3 -0
- checkpoints/stylegan2_dogs_1024_pytorch.pkl +3 -0
- checkpoints/stylegan2_elephants_512_pytorch.pkl +3 -0
- checkpoints/stylegan2_horses_256_pytorch.pkl +3 -0
- checkpoints/stylegan2_lions_512_pytorch.pkl +3 -0
- checkpoints/stylegan_human_v2_512.pkl +3 -0
- dnnlib/__init__.py +9 -0
- dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
- dnnlib/__pycache__/util.cpython-39.pyc +0 -0
- dnnlib/util.py +491 -0
- environment.yml +27 -0
- gen_images.py +150 -0
- gradio_utils/__init__.py +9 -0
- gradio_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- gradio_utils/__pycache__/utils.cpython-39.pyc +0 -0
- gradio_utils/utils.py +154 -0
- gui_utils/__init__.py +9 -0
- gui_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- gui_utils/__pycache__/gl_utils.cpython-39.pyc +0 -0
- gui_utils/__pycache__/glfw_window.cpython-39.pyc +0 -0
- gui_utils/__pycache__/imgui_utils.cpython-39.pyc +0 -0
- gui_utils/__pycache__/imgui_window.cpython-39.pyc +0 -0
- gui_utils/__pycache__/text_utils.cpython-39.pyc +0 -0
- gui_utils/gl_utils.py +416 -0
- gui_utils/glfw_window.py +229 -0
- gui_utils/imgui_utils.py +191 -0
- gui_utils/imgui_window.py +103 -0
- gui_utils/text_utils.py +123 -0
- legacy.py +323 -0
- requirements.txt +13 -0
- scripts/download_model.py +78 -0
- scripts/download_models.json +10 -0
- scripts/gui.bat +12 -0
- scripts/gui.sh +11 -0
- stylegan_human/.gitignore +10 -0
- stylegan_human/PP_HumanSeg/deploy/infer.py +180 -0
- stylegan_human/PP_HumanSeg/export_model/download_export_model.py +44 -0
- stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py +44 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
DragGAN.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
stylegan_human/img/demo_V5_thumbnails-min.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
stylegan_human/img/preview_samples1.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
stylegan_human/img/test/test.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
2 |
+
### Python template
|
3 |
+
# Byte-compiled / optimized / DLL files
|
4 |
+
__pycache__/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
env/
|
14 |
+
build/
|
15 |
+
develop-eggs/
|
16 |
+
dist/
|
17 |
+
downloads/
|
18 |
+
eggs/
|
19 |
+
.eggs/
|
20 |
+
lib/
|
21 |
+
lib64/
|
22 |
+
parts/
|
23 |
+
sdist/
|
24 |
+
var/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.coverage
|
43 |
+
.coverage.*
|
44 |
+
.cache
|
45 |
+
nosetests.xml
|
46 |
+
coverage.xml
|
47 |
+
*,cover
|
48 |
+
.hypothesis/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
|
58 |
+
# Flask stuff:
|
59 |
+
instance/
|
60 |
+
.webassets-cache
|
61 |
+
|
62 |
+
# Scrapy stuff:
|
63 |
+
.scrapy
|
64 |
+
|
65 |
+
# Sphinx documentation
|
66 |
+
docs/_build/
|
67 |
+
|
68 |
+
# PyBuilder
|
69 |
+
target/
|
70 |
+
|
71 |
+
# IPython Notebook
|
72 |
+
.ipynb_checkpoints
|
73 |
+
|
74 |
+
# pyenv
|
75 |
+
.python-version
|
76 |
+
|
77 |
+
# celery beat schedule file
|
78 |
+
celerybeat-schedule
|
79 |
+
|
80 |
+
# dotenv
|
81 |
+
.env
|
82 |
+
|
83 |
+
# virtualenv
|
84 |
+
venv/
|
85 |
+
ENV/
|
86 |
+
|
87 |
+
# Spyder project settings
|
88 |
+
.spyderproject
|
89 |
+
|
90 |
+
# Rope project settings
|
91 |
+
.ropeproject
|
92 |
+
### VirtualEnv template
|
93 |
+
# Virtualenv
|
94 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
95 |
+
.Python
|
96 |
+
[Bb]in
|
97 |
+
[Ii]nclude
|
98 |
+
[Ll]ib
|
99 |
+
[Ll]ib64
|
100 |
+
[Ll]ocal
|
101 |
+
[Ss]cripts
|
102 |
+
!scripts/
|
103 |
+
pyvenv.cfg
|
104 |
+
.venv
|
105 |
+
pip-selfcheck.json
|
106 |
+
### JetBrains template
|
107 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
|
108 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
109 |
+
|
110 |
+
# User-specific stuff:
|
111 |
+
.idea/workspace.xml
|
112 |
+
.idea/tasks.xml
|
113 |
+
.idea/dictionaries
|
114 |
+
.idea/vcs.xml
|
115 |
+
.idea/jsLibraryMappings.xml
|
116 |
+
|
117 |
+
# Sensitive or high-churn files:
|
118 |
+
.idea/dataSources.ids
|
119 |
+
.idea/dataSources.xml
|
120 |
+
.idea/dataSources.local.xml
|
121 |
+
.idea/sqlDataSources.xml
|
122 |
+
.idea/dynamic.xml
|
123 |
+
.idea/uiDesigner.xml
|
124 |
+
|
125 |
+
# Gradle:
|
126 |
+
.idea/gradle.xml
|
127 |
+
.idea/libraries
|
128 |
+
|
129 |
+
# Mongo Explorer plugin:
|
130 |
+
.idea/mongoSettings.xml
|
131 |
+
|
132 |
+
.idea/
|
133 |
+
|
134 |
+
## File-based project format:
|
135 |
+
*.iws
|
136 |
+
|
137 |
+
## Plugin-specific files:
|
138 |
+
|
139 |
+
# IntelliJ
|
140 |
+
/out/
|
141 |
+
|
142 |
+
# mpeltonen/sbt-idea plugin
|
143 |
+
.idea_modules/
|
144 |
+
|
145 |
+
# JIRA plugin
|
146 |
+
atlassian-ide-plugin.xml
|
147 |
+
|
148 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
149 |
+
com_crashlytics_export_strings.xml
|
150 |
+
crashlytics.properties
|
151 |
+
crashlytics-build.properties
|
152 |
+
fabric.properties
|
153 |
+
|
154 |
+
# Mac related
|
155 |
+
.DS_Store
|
156 |
+
|
157 |
+
checkpoints
|
Dockerfile
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/pytorch:23.05-py3
|
2 |
+
|
3 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
4 |
+
ENV PYTHONUNBUFFERED 1
|
5 |
+
|
6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
7 |
+
make \
|
8 |
+
pkgconf \
|
9 |
+
xz-utils \
|
10 |
+
xorg-dev \
|
11 |
+
libgl1-mesa-dev \
|
12 |
+
libglu1-mesa-dev \
|
13 |
+
libxrandr-dev \
|
14 |
+
libxinerama-dev \
|
15 |
+
libxcursor-dev \
|
16 |
+
libxi-dev \
|
17 |
+
libxxf86vm-dev \
|
18 |
+
&& rm -rf /var/lib/apt/lists/*
|
19 |
+
|
20 |
+
RUN pip install --no-cache-dir --upgrade pip
|
21 |
+
|
22 |
+
COPY requirements.txt .
|
23 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
24 |
+
|
25 |
+
WORKDIR /workspace
|
26 |
+
|
27 |
+
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
|
28 |
+
ENTRYPOINT ["/entry.sh"]
|
DragGAN.gif
ADDED
![]() |
Git LFS Details
|
LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
NVIDIA Source Code License for StyleGAN3
|
5 |
+
|
6 |
+
|
7 |
+
=======================================================================
|
8 |
+
|
9 |
+
1. Definitions
|
10 |
+
|
11 |
+
"Licensor" means any person or entity that distributes its Work.
|
12 |
+
|
13 |
+
"Software" means the original work of authorship made available under
|
14 |
+
this License.
|
15 |
+
|
16 |
+
"Work" means the Software and any additions to or derivative works of
|
17 |
+
the Software that are made available under this License.
|
18 |
+
|
19 |
+
The terms "reproduce," "reproduction," "derivative works," and
|
20 |
+
"distribution" have the meaning as provided under U.S. copyright law;
|
21 |
+
provided, however, that for the purposes of this License, derivative
|
22 |
+
works shall not include works that remain separable from, or merely
|
23 |
+
link (or bind by name) to the interfaces of, the Work.
|
24 |
+
|
25 |
+
Works, including the Software, are "made available" under this License
|
26 |
+
by including in or with the Work either (a) a copyright notice
|
27 |
+
referencing the applicability of this License to the Work, or (b) a
|
28 |
+
copy of this License.
|
29 |
+
|
30 |
+
2. License Grants
|
31 |
+
|
32 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this
|
33 |
+
License, each Licensor grants to you a perpetual, worldwide,
|
34 |
+
non-exclusive, royalty-free, copyright license to reproduce,
|
35 |
+
prepare derivative works of, publicly display, publicly perform,
|
36 |
+
sublicense and distribute its Work and any resulting derivative
|
37 |
+
works in any form.
|
38 |
+
|
39 |
+
3. Limitations
|
40 |
+
|
41 |
+
3.1 Redistribution. You may reproduce or distribute the Work only
|
42 |
+
if (a) you do so under this License, (b) you include a complete
|
43 |
+
copy of this License with your distribution, and (c) you retain
|
44 |
+
without modification any copyright, patent, trademark, or
|
45 |
+
attribution notices that are present in the Work.
|
46 |
+
|
47 |
+
3.2 Derivative Works. You may specify that additional or different
|
48 |
+
terms apply to the use, reproduction, and distribution of your
|
49 |
+
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
50 |
+
provide that the use limitation in Section 3.3 applies to your
|
51 |
+
derivative works, and (b) you identify the specific derivative
|
52 |
+
works that are subject to Your Terms. Notwithstanding Your Terms,
|
53 |
+
this License (including the redistribution requirements in Section
|
54 |
+
3.1) will continue to apply to the Work itself.
|
55 |
+
|
56 |
+
3.3 Use Limitation. The Work and any derivative works thereof only
|
57 |
+
may be used or intended for use non-commercially. Notwithstanding
|
58 |
+
the foregoing, NVIDIA and its affiliates may use the Work and any
|
59 |
+
derivative works commercially. As used herein, "non-commercially"
|
60 |
+
means for research or evaluation purposes only.
|
61 |
+
|
62 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
63 |
+
against any Licensor (including any claim, cross-claim or
|
64 |
+
counterclaim in a lawsuit) to enforce any patents that you allege
|
65 |
+
are infringed by any Work, then your rights under this License from
|
66 |
+
such Licensor (including the grant in Section 2.1) will terminate
|
67 |
+
immediately.
|
68 |
+
|
69 |
+
3.5 Trademarks. This License does not grant any rights to use any
|
70 |
+
Licensor’s or its affiliates’ names, logos, or trademarks, except
|
71 |
+
as necessary to reproduce the notices described in this License.
|
72 |
+
|
73 |
+
3.6 Termination. If you violate any term of this License, then your
|
74 |
+
rights under this License (including the grant in Section 2.1) will
|
75 |
+
terminate immediately.
|
76 |
+
|
77 |
+
4. Disclaimer of Warranty.
|
78 |
+
|
79 |
+
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
80 |
+
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
81 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
82 |
+
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
83 |
+
THIS LICENSE.
|
84 |
+
|
85 |
+
5. Limitation of Liability.
|
86 |
+
|
87 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
88 |
+
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
89 |
+
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
90 |
+
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
91 |
+
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
92 |
+
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
93 |
+
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
94 |
+
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
95 |
+
THE POSSIBILITY OF SUCH DAMAGES.
|
96 |
+
|
97 |
+
=======================================================================
|
README.md
CHANGED
@@ -1,12 +1,144 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: shuangzizuoDragGAN
|
3 |
+
app_file: visualizer_drag_gradio.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 3.35.2
|
|
|
|
|
6 |
---
|
7 |
+
<p align="center">
|
8 |
|
9 |
+
<h1 align="center">Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold</h1>
|
10 |
+
<p align="center">
|
11 |
+
<a href="https://xingangpan.github.io/"><strong>Xingang Pan</strong></a>
|
12 |
+
·
|
13 |
+
<a href="https://ayushtewari.com/"><strong>Ayush Tewari</strong></a>
|
14 |
+
·
|
15 |
+
<a href="https://people.mpi-inf.mpg.de/~tleimkue/"><strong>Thomas Leimkühler</strong></a>
|
16 |
+
·
|
17 |
+
<a href="https://lingjie0206.github.io/"><strong>Lingjie Liu</strong></a>
|
18 |
+
·
|
19 |
+
<a href="https://www.meka.page/"><strong>Abhimitra Meka</strong></a>
|
20 |
+
·
|
21 |
+
<a href="http://www.mpi-inf.mpg.de/~theobalt/"><strong>Christian Theobalt</strong></a>
|
22 |
+
</p>
|
23 |
+
<h2 align="center">SIGGRAPH 2023 Conference Proceedings</h2>
|
24 |
+
<div align="center">
|
25 |
+
<img src="DragGAN.gif", width="600">
|
26 |
+
</div>
|
27 |
+
|
28 |
+
<p align="center">
|
29 |
+
<br>
|
30 |
+
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
|
31 |
+
<a href="https://twitter.com/XingangP"><img alt='Twitter' src="https://img.shields.io/twitter/follow/XingangP?label=%40XingangP"></a>
|
32 |
+
<a href="https://arxiv.org/abs/2305.10973">
|
33 |
+
<img src='https://img.shields.io/badge/Paper-PDF-green?style=for-the-badge&logo=adobeacrobatreader&logoWidth=20&logoColor=white&labelColor=66cc00&color=94DD15' alt='Paper PDF'>
|
34 |
+
</a>
|
35 |
+
<a href='https://vcai.mpi-inf.mpg.de/projects/DragGAN/'>
|
36 |
+
<img src='https://img.shields.io/badge/DragGAN-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=white&labelColor=D35400' alt='Project Page'></a>
|
37 |
+
<a href="https://colab.research.google.com/drive/1mey-IXPwQC_qSthI5hO-LTX7QL4ivtPh?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
38 |
+
</p>
|
39 |
+
</p>
|
40 |
+
|
41 |
+
## Web Demos
|
42 |
+
|
43 |
+
[](https://openxlab.org.cn/apps/detail/XingangPan/DragGAN)
|
44 |
+
|
45 |
+
<p align="left">
|
46 |
+
<a href="https://huggingface.co/spaces/radames/DragGan"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DragGAN-orange"></a>
|
47 |
+
</p>
|
48 |
+
|
49 |
+
## Requirements
|
50 |
+
|
51 |
+
If you have CUDA graphic card, please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements).
|
52 |
+
|
53 |
+
The usual installation steps involve the following commands, they should set up the correct CUDA version and all the python packages
|
54 |
+
|
55 |
+
```
|
56 |
+
conda env create -f environment.yml
|
57 |
+
conda activate stylegan3
|
58 |
+
```
|
59 |
+
|
60 |
+
Then install the additional requirements
|
61 |
+
|
62 |
+
```
|
63 |
+
pip install -r requirements.txt
|
64 |
+
```
|
65 |
+
|
66 |
+
Otherwise (for GPU acceleration on MacOS with Silicon Mac M1/M2, or just CPU) try the following:
|
67 |
+
|
68 |
+
```sh
|
69 |
+
cat environment.yml | \
|
70 |
+
grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
|
71 |
+
conda env create -f environment-no-nvidia.yml
|
72 |
+
conda activate stylegan3
|
73 |
+
|
74 |
+
# On MacOS
|
75 |
+
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
76 |
+
```
|
77 |
+
|
78 |
+
## Run Gradio visualizer in Docker
|
79 |
+
|
80 |
+
Provided docker image is based on NGC PyTorch repository. To quickly try out visualizer in Docker, run the following:
|
81 |
+
|
82 |
+
```sh
|
83 |
+
# before you build the docker container, make sure you have cloned this repo, and downloaded the pretrained model by `python scripts/download_model.py`.
|
84 |
+
docker build . -t draggan:latest
|
85 |
+
docker run -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
|
86 |
+
# (Use GPU)if you want to utilize your Nvidia gpu to accelerate in docker, please add command tag `--gpus all`, like:
|
87 |
+
# docker run --gpus all -p 7860:7860 -v "$PWD":/workspace/src -it draggan:latest bash
|
88 |
+
|
89 |
+
cd src && python visualizer_drag_gradio.py --listen
|
90 |
+
```
|
91 |
+
Now you can open a shared link from Gradio (printed in the terminal console).
|
92 |
+
Beware the Docker image takes about 25GB of disk space!
|
93 |
+
|
94 |
+
## Download pre-trained StyleGAN2 weights
|
95 |
+
|
96 |
+
To download pre-trained weights, simply run:
|
97 |
+
|
98 |
+
```
|
99 |
+
python scripts/download_model.py
|
100 |
+
```
|
101 |
+
If you want to try StyleGAN-Human and the Landscapes HQ (LHQ) dataset, please download weights from these links: [StyleGAN-Human](https://drive.google.com/file/d/1dlFEHbu-WzQWJl7nBBZYcTyo000H9hVm/view?usp=sharing), [LHQ](https://drive.google.com/file/d/16twEf0T9QINAEoMsWefoWiyhcTd-aiWc/view?usp=sharing), and put them under `./checkpoints`.
|
102 |
+
|
103 |
+
Feel free to try other pretrained StyleGAN.
|
104 |
+
|
105 |
+
## Run DragGAN GUI
|
106 |
+
|
107 |
+
To start the DragGAN GUI, simply run:
|
108 |
+
```sh
|
109 |
+
sh scripts/gui.sh
|
110 |
+
```
|
111 |
+
If you are using windows, you can run:
|
112 |
+
```
|
113 |
+
.\scripts\gui.bat
|
114 |
+
```
|
115 |
+
|
116 |
+
This GUI supports editing GAN-generated images. To edit a real image, you need to first perform GAN inversion using tools like [PTI](https://github.com/danielroich/PTI). Then load the new latent code and model weights to the GUI.
|
117 |
+
|
118 |
+
You can run DragGAN Gradio demo as well, this is universal for both windows and linux:
|
119 |
+
```sh
|
120 |
+
python visualizer_drag_gradio.py
|
121 |
+
```
|
122 |
+
|
123 |
+
## Acknowledgement
|
124 |
+
|
125 |
+
This code is developed based on [StyleGAN3](https://github.com/NVlabs/stylegan3). Part of the code is borrowed from [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human).
|
126 |
+
|
127 |
+
(cheers to the community as well)
|
128 |
+
## License
|
129 |
+
|
130 |
+
The code related to the DragGAN algorithm is licensed under [CC-BY-NC](https://creativecommons.org/licenses/by-nc/4.0/).
|
131 |
+
However, most of this project are available under a separate license terms: all codes used or modified from [StyleGAN3](https://github.com/NVlabs/stylegan3) is under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
|
132 |
+
|
133 |
+
Any form of use and derivative of this code must preserve the watermarking functionality showing "AI Generated".
|
134 |
+
|
135 |
+
## BibTeX
|
136 |
+
|
137 |
+
```bibtex
|
138 |
+
@inproceedings{pan2023draggan,
|
139 |
+
title={Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold},
|
140 |
+
author={Pan, Xingang and Tewari, Ayush, and Leimk{\"u}hler, Thomas and Liu, Lingjie and Meka, Abhimitra and Theobalt, Christian},
|
141 |
+
booktitle = {ACM SIGGRAPH 2023 Conference Proceedings},
|
142 |
+
year={2023}
|
143 |
+
}
|
144 |
+
```
|
__pycache__/legacy.cpython-39.pyc
ADDED
Binary file (15 kB). View file
|
|
arial.ttf
ADDED
Binary file (276 kB). View file
|
|
checkpoints/stylegan2-afhqcat-512x512.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17a83abee464242f8bb40dc6363d33c2fb087066b68fc0147677fdbf21f7a7a9
|
3 |
+
size 363939583
|
checkpoints/stylegan2-car-config-f.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1618eee2b3ce87c4a3849442f7850ef12a478556bff035c8e09ee7e23b3794c
|
3 |
+
size 364027523
|
checkpoints/stylegan2-cat-config-f.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08940fd616cfaf6bc7b0286b5d1a0b3f70febb26e136d64716c8d3f5e9bd3883
|
3 |
+
size 357418027
|
checkpoints/stylegan2-ffhq-512x512.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d2b1c92f41ce8a64c55f7a75ef06c4c0eef9e17b1eb29aae8c10fb37b3e60478
|
3 |
+
size 363939580
|
checkpoints/stylegan2-lhq-256x256.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae2c33d456ff3f274d472da9849c9bb515b7f0032af0d713d0d2d7f42b7942dc
|
3 |
+
size 357314829
|
checkpoints/stylegan2_dogs_1024_pytorch.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9e93090d02916165a602c728c0e37458fc0c58fbc58e4d75bcd096bb81c7e8c
|
3 |
+
size 381630441
|
checkpoints/stylegan2_elephants_512_pytorch.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:56504894f6d7121959af78a74148cf1e9d858e3710312efb11c41dbf27684363
|
3 |
+
size 363965313
|
checkpoints/stylegan2_horses_256_pytorch.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f100dc32e2731a293f3b31a9038416f72aa5cc30555b3315a82e19c065f81b0c
|
3 |
+
size 357336721
|
checkpoints/stylegan2_lions_512_pytorch.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a01ff8344521171b1a2eff1e9a51c1acbc48221bdc2594919187f66a3942bcc
|
3 |
+
size 363965313
|
checkpoints/stylegan_human_v2_512.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7af33f638e3b3aba7bf99456eec3e9d4a022d8a7dd67683a3605a7dd37665a3b
|
3 |
+
size 352745981
|
dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
from .util import EasyDict, make_cache_dir_path
|
dnnlib/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (187 Bytes). View file
|
|
dnnlib/__pycache__/util.cpython-39.pyc
ADDED
Binary file (14 kB). View file
|
|
dnnlib/util.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utility classes and functions."""
|
10 |
+
|
11 |
+
import ctypes
|
12 |
+
import fnmatch
|
13 |
+
import importlib
|
14 |
+
import inspect
|
15 |
+
import numpy as np
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
import sys
|
19 |
+
import types
|
20 |
+
import io
|
21 |
+
import pickle
|
22 |
+
import re
|
23 |
+
import requests
|
24 |
+
import html
|
25 |
+
import hashlib
|
26 |
+
import glob
|
27 |
+
import tempfile
|
28 |
+
import urllib
|
29 |
+
import urllib.request
|
30 |
+
import uuid
|
31 |
+
|
32 |
+
from distutils.util import strtobool
|
33 |
+
from typing import Any, List, Tuple, Union
|
34 |
+
|
35 |
+
|
36 |
+
# Util classes
|
37 |
+
# ------------------------------------------------------------------------------------------
|
38 |
+
|
39 |
+
|
40 |
+
class EasyDict(dict):
|
41 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
42 |
+
|
43 |
+
def __getattr__(self, name: str) -> Any:
|
44 |
+
try:
|
45 |
+
return self[name]
|
46 |
+
except KeyError:
|
47 |
+
raise AttributeError(name)
|
48 |
+
|
49 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
50 |
+
self[name] = value
|
51 |
+
|
52 |
+
def __delattr__(self, name: str) -> None:
|
53 |
+
del self[name]
|
54 |
+
|
55 |
+
|
56 |
+
class Logger(object):
|
57 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
58 |
+
|
59 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
60 |
+
self.file = None
|
61 |
+
|
62 |
+
if file_name is not None:
|
63 |
+
self.file = open(file_name, file_mode)
|
64 |
+
|
65 |
+
self.should_flush = should_flush
|
66 |
+
self.stdout = sys.stdout
|
67 |
+
self.stderr = sys.stderr
|
68 |
+
|
69 |
+
sys.stdout = self
|
70 |
+
sys.stderr = self
|
71 |
+
|
72 |
+
def __enter__(self) -> "Logger":
|
73 |
+
return self
|
74 |
+
|
75 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
76 |
+
self.close()
|
77 |
+
|
78 |
+
def write(self, text: Union[str, bytes]) -> None:
|
79 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
80 |
+
if isinstance(text, bytes):
|
81 |
+
text = text.decode()
|
82 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
83 |
+
return
|
84 |
+
|
85 |
+
if self.file is not None:
|
86 |
+
self.file.write(text)
|
87 |
+
|
88 |
+
self.stdout.write(text)
|
89 |
+
|
90 |
+
if self.should_flush:
|
91 |
+
self.flush()
|
92 |
+
|
93 |
+
def flush(self) -> None:
|
94 |
+
"""Flush written text to both stdout and a file, if open."""
|
95 |
+
if self.file is not None:
|
96 |
+
self.file.flush()
|
97 |
+
|
98 |
+
self.stdout.flush()
|
99 |
+
|
100 |
+
def close(self) -> None:
|
101 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
102 |
+
self.flush()
|
103 |
+
|
104 |
+
# if using multiple loggers, prevent closing in wrong order
|
105 |
+
if sys.stdout is self:
|
106 |
+
sys.stdout = self.stdout
|
107 |
+
if sys.stderr is self:
|
108 |
+
sys.stderr = self.stderr
|
109 |
+
|
110 |
+
if self.file is not None:
|
111 |
+
self.file.close()
|
112 |
+
self.file = None
|
113 |
+
|
114 |
+
|
115 |
+
# Cache directories
|
116 |
+
# ------------------------------------------------------------------------------------------
|
117 |
+
|
118 |
+
_dnnlib_cache_dir = None
|
119 |
+
|
120 |
+
def set_cache_dir(path: str) -> None:
|
121 |
+
global _dnnlib_cache_dir
|
122 |
+
_dnnlib_cache_dir = path
|
123 |
+
|
124 |
+
def make_cache_dir_path(*paths: str) -> str:
|
125 |
+
if _dnnlib_cache_dir is not None:
|
126 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
127 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
128 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
129 |
+
if 'HOME' in os.environ:
|
130 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
131 |
+
if 'USERPROFILE' in os.environ:
|
132 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
133 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
134 |
+
|
135 |
+
# Small util functions
|
136 |
+
# ------------------------------------------------------------------------------------------
|
137 |
+
|
138 |
+
|
139 |
+
def format_time(seconds: Union[int, float]) -> str:
|
140 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
141 |
+
s = int(np.rint(seconds))
|
142 |
+
|
143 |
+
if s < 60:
|
144 |
+
return "{0}s".format(s)
|
145 |
+
elif s < 60 * 60:
|
146 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
147 |
+
elif s < 24 * 60 * 60:
|
148 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
149 |
+
else:
|
150 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
151 |
+
|
152 |
+
|
153 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
154 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
155 |
+
s = int(np.rint(seconds))
|
156 |
+
|
157 |
+
if s < 60:
|
158 |
+
return "{0}s".format(s)
|
159 |
+
elif s < 60 * 60:
|
160 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
161 |
+
elif s < 24 * 60 * 60:
|
162 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
163 |
+
else:
|
164 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
165 |
+
|
166 |
+
|
167 |
+
def ask_yes_no(question: str) -> bool:
|
168 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
169 |
+
while True:
|
170 |
+
try:
|
171 |
+
print("{0} [y/n]".format(question))
|
172 |
+
return strtobool(input().lower())
|
173 |
+
except ValueError:
|
174 |
+
pass
|
175 |
+
|
176 |
+
|
177 |
+
def tuple_product(t: Tuple) -> Any:
|
178 |
+
"""Calculate the product of the tuple elements."""
|
179 |
+
result = 1
|
180 |
+
|
181 |
+
for v in t:
|
182 |
+
result *= v
|
183 |
+
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
_str_to_ctype = {
|
188 |
+
"uint8": ctypes.c_ubyte,
|
189 |
+
"uint16": ctypes.c_uint16,
|
190 |
+
"uint32": ctypes.c_uint32,
|
191 |
+
"uint64": ctypes.c_uint64,
|
192 |
+
"int8": ctypes.c_byte,
|
193 |
+
"int16": ctypes.c_int16,
|
194 |
+
"int32": ctypes.c_int32,
|
195 |
+
"int64": ctypes.c_int64,
|
196 |
+
"float32": ctypes.c_float,
|
197 |
+
"float64": ctypes.c_double
|
198 |
+
}
|
199 |
+
|
200 |
+
|
201 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
202 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
203 |
+
type_str = None
|
204 |
+
|
205 |
+
if isinstance(type_obj, str):
|
206 |
+
type_str = type_obj
|
207 |
+
elif hasattr(type_obj, "__name__"):
|
208 |
+
type_str = type_obj.__name__
|
209 |
+
elif hasattr(type_obj, "name"):
|
210 |
+
type_str = type_obj.name
|
211 |
+
else:
|
212 |
+
raise RuntimeError("Cannot infer type name from input")
|
213 |
+
|
214 |
+
assert type_str in _str_to_ctype.keys()
|
215 |
+
|
216 |
+
my_dtype = np.dtype(type_str)
|
217 |
+
my_ctype = _str_to_ctype[type_str]
|
218 |
+
|
219 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
220 |
+
|
221 |
+
return my_dtype, my_ctype
|
222 |
+
|
223 |
+
|
224 |
+
def is_pickleable(obj: Any) -> bool:
|
225 |
+
try:
|
226 |
+
with io.BytesIO() as stream:
|
227 |
+
pickle.dump(obj, stream)
|
228 |
+
return True
|
229 |
+
except:
|
230 |
+
return False
|
231 |
+
|
232 |
+
|
233 |
+
# Functionality to import modules/objects by name, and call functions by name
|
234 |
+
# ------------------------------------------------------------------------------------------
|
235 |
+
|
236 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
237 |
+
"""Searches for the underlying module behind the name to some python object.
|
238 |
+
Returns the module and the object name (original name with module part removed)."""
|
239 |
+
|
240 |
+
# allow convenience shorthands, substitute them by full names
|
241 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
242 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
243 |
+
|
244 |
+
# list alternatives for (module_name, local_obj_name)
|
245 |
+
parts = obj_name.split(".")
|
246 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
247 |
+
|
248 |
+
# try each alternative in turn
|
249 |
+
for module_name, local_obj_name in name_pairs:
|
250 |
+
try:
|
251 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
252 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
253 |
+
return module, local_obj_name
|
254 |
+
except:
|
255 |
+
pass
|
256 |
+
|
257 |
+
# maybe some of the modules themselves contain errors?
|
258 |
+
for module_name, _local_obj_name in name_pairs:
|
259 |
+
try:
|
260 |
+
importlib.import_module(module_name) # may raise ImportError
|
261 |
+
except ImportError:
|
262 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
263 |
+
raise
|
264 |
+
|
265 |
+
# maybe the requested attribute is missing?
|
266 |
+
for module_name, local_obj_name in name_pairs:
|
267 |
+
try:
|
268 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
269 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
270 |
+
except ImportError:
|
271 |
+
pass
|
272 |
+
|
273 |
+
# we are out of luck, but we have no idea why
|
274 |
+
raise ImportError(obj_name)
|
275 |
+
|
276 |
+
|
277 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
278 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
279 |
+
if obj_name == '':
|
280 |
+
return module
|
281 |
+
obj = module
|
282 |
+
for part in obj_name.split("."):
|
283 |
+
obj = getattr(obj, part)
|
284 |
+
return obj
|
285 |
+
|
286 |
+
|
287 |
+
def get_obj_by_name(name: str) -> Any:
|
288 |
+
"""Finds the python object with the given name."""
|
289 |
+
module, obj_name = get_module_from_obj_name(name)
|
290 |
+
return get_obj_from_module(module, obj_name)
|
291 |
+
|
292 |
+
|
293 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
294 |
+
"""Finds the python object with the given name and calls it as a function."""
|
295 |
+
assert func_name is not None
|
296 |
+
func_obj = get_obj_by_name(func_name)
|
297 |
+
assert callable(func_obj)
|
298 |
+
return func_obj(*args, **kwargs)
|
299 |
+
|
300 |
+
|
301 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
302 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
303 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
304 |
+
|
305 |
+
|
306 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
307 |
+
"""Get the directory path of the module containing the given object name."""
|
308 |
+
module, _ = get_module_from_obj_name(obj_name)
|
309 |
+
return os.path.dirname(inspect.getfile(module))
|
310 |
+
|
311 |
+
|
312 |
+
def is_top_level_function(obj: Any) -> bool:
|
313 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
314 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
315 |
+
|
316 |
+
|
317 |
+
def get_top_level_function_name(obj: Any) -> str:
|
318 |
+
"""Return the fully-qualified name of a top-level function."""
|
319 |
+
assert is_top_level_function(obj)
|
320 |
+
module = obj.__module__
|
321 |
+
if module == '__main__':
|
322 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
323 |
+
return module + "." + obj.__name__
|
324 |
+
|
325 |
+
|
326 |
+
# File system helpers
|
327 |
+
# ------------------------------------------------------------------------------------------
|
328 |
+
|
329 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
330 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
331 |
+
Returns list of tuples containing both absolute and relative paths."""
|
332 |
+
assert os.path.isdir(dir_path)
|
333 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
334 |
+
|
335 |
+
if ignores is None:
|
336 |
+
ignores = []
|
337 |
+
|
338 |
+
result = []
|
339 |
+
|
340 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
341 |
+
for ignore_ in ignores:
|
342 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
343 |
+
|
344 |
+
# dirs need to be edited in-place
|
345 |
+
for d in dirs_to_remove:
|
346 |
+
dirs.remove(d)
|
347 |
+
|
348 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
349 |
+
|
350 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
351 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
352 |
+
|
353 |
+
if add_base_to_relative:
|
354 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
355 |
+
|
356 |
+
assert len(absolute_paths) == len(relative_paths)
|
357 |
+
result += zip(absolute_paths, relative_paths)
|
358 |
+
|
359 |
+
return result
|
360 |
+
|
361 |
+
|
362 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
363 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
364 |
+
Will create all necessary directories."""
|
365 |
+
for file in files:
|
366 |
+
target_dir_name = os.path.dirname(file[1])
|
367 |
+
|
368 |
+
# will create all intermediate-level directories
|
369 |
+
if not os.path.exists(target_dir_name):
|
370 |
+
os.makedirs(target_dir_name)
|
371 |
+
|
372 |
+
shutil.copyfile(file[0], file[1])
|
373 |
+
|
374 |
+
|
375 |
+
# URL helpers
|
376 |
+
# ------------------------------------------------------------------------------------------
|
377 |
+
|
378 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
379 |
+
"""Determine whether the given object is a valid URL string."""
|
380 |
+
if not isinstance(obj, str) or not "://" in obj:
|
381 |
+
return False
|
382 |
+
if allow_file_urls and obj.startswith('file://'):
|
383 |
+
return True
|
384 |
+
try:
|
385 |
+
res = requests.compat.urlparse(obj)
|
386 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
387 |
+
return False
|
388 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
389 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
390 |
+
return False
|
391 |
+
except:
|
392 |
+
return False
|
393 |
+
return True
|
394 |
+
|
395 |
+
|
396 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
397 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
398 |
+
assert num_attempts >= 1
|
399 |
+
assert not (return_filename and (not cache))
|
400 |
+
|
401 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
402 |
+
if not re.match('^[a-z]+://', url):
|
403 |
+
return url if return_filename else open(url, "rb")
|
404 |
+
|
405 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
406 |
+
# arise on Windows:
|
407 |
+
#
|
408 |
+
# file:///c:/foo.txt
|
409 |
+
#
|
410 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
411 |
+
# invalid. Drop the forward slash for such pathnames.
|
412 |
+
#
|
413 |
+
# If you touch this code path, you should test it on both Linux and
|
414 |
+
# Windows.
|
415 |
+
#
|
416 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
417 |
+
# but that converts forward slashes to backslashes and this causes
|
418 |
+
# its own set of problems.
|
419 |
+
if url.startswith('file://'):
|
420 |
+
filename = urllib.parse.urlparse(url).path
|
421 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
422 |
+
filename = filename[1:]
|
423 |
+
return filename if return_filename else open(filename, "rb")
|
424 |
+
|
425 |
+
assert is_url(url)
|
426 |
+
|
427 |
+
# Lookup from cache.
|
428 |
+
if cache_dir is None:
|
429 |
+
cache_dir = make_cache_dir_path('downloads')
|
430 |
+
|
431 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
432 |
+
if cache:
|
433 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
434 |
+
if len(cache_files) == 1:
|
435 |
+
filename = cache_files[0]
|
436 |
+
return filename if return_filename else open(filename, "rb")
|
437 |
+
|
438 |
+
# Download.
|
439 |
+
url_name = None
|
440 |
+
url_data = None
|
441 |
+
with requests.Session() as session:
|
442 |
+
if verbose:
|
443 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
444 |
+
for attempts_left in reversed(range(num_attempts)):
|
445 |
+
try:
|
446 |
+
with session.get(url) as res:
|
447 |
+
res.raise_for_status()
|
448 |
+
if len(res.content) == 0:
|
449 |
+
raise IOError("No data received")
|
450 |
+
|
451 |
+
if len(res.content) < 8192:
|
452 |
+
content_str = res.content.decode("utf-8")
|
453 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
454 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
455 |
+
if len(links) == 1:
|
456 |
+
url = requests.compat.urljoin(url, links[0])
|
457 |
+
raise IOError("Google Drive virus checker nag")
|
458 |
+
if "Google Drive - Quota exceeded" in content_str:
|
459 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
460 |
+
|
461 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
462 |
+
url_name = match[1] if match else url
|
463 |
+
url_data = res.content
|
464 |
+
if verbose:
|
465 |
+
print(" done")
|
466 |
+
break
|
467 |
+
except KeyboardInterrupt:
|
468 |
+
raise
|
469 |
+
except:
|
470 |
+
if not attempts_left:
|
471 |
+
if verbose:
|
472 |
+
print(" failed")
|
473 |
+
raise
|
474 |
+
if verbose:
|
475 |
+
print(".", end="", flush=True)
|
476 |
+
|
477 |
+
# Save to cache.
|
478 |
+
if cache:
|
479 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
480 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
481 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
482 |
+
os.makedirs(cache_dir, exist_ok=True)
|
483 |
+
with open(temp_file, "wb") as f:
|
484 |
+
f.write(url_data)
|
485 |
+
os.replace(temp_file, cache_file) # atomic
|
486 |
+
if return_filename:
|
487 |
+
return cache_file
|
488 |
+
|
489 |
+
# Return data as file object.
|
490 |
+
assert not return_filename
|
491 |
+
return io.BytesIO(url_data)
|
environment.yml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: stylegan3
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
dependencies:
|
6 |
+
- python >= 3.8
|
7 |
+
- pip
|
8 |
+
- numpy>=1.25
|
9 |
+
- click>=8.0
|
10 |
+
- pillow=9.4.0
|
11 |
+
#- scipy=1.11.0
|
12 |
+
- pytorch>=2.0.1
|
13 |
+
- torchvision>=0.15.2
|
14 |
+
#- cudatoolkit=11.1
|
15 |
+
- requests=2.26.0
|
16 |
+
- tqdm=4.62.2
|
17 |
+
- ninja=1.10.2
|
18 |
+
- matplotlib=3.4.2
|
19 |
+
- imageio=2.9.0
|
20 |
+
- pip:
|
21 |
+
- imgui==2.0.0
|
22 |
+
- glfw==2.6.1
|
23 |
+
- gradio==3.35.2
|
24 |
+
- pyopengl==3.1.5
|
25 |
+
- imageio-ffmpeg==0.4.3
|
26 |
+
# pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
|
27 |
+
- pyspng-seunglab
|
gen_images.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Generate images using pretrained network pickle."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import re
|
13 |
+
from typing import List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import click
|
16 |
+
import dnnlib
|
17 |
+
import numpy as np
|
18 |
+
import PIL.Image
|
19 |
+
import torch
|
20 |
+
|
21 |
+
import legacy
|
22 |
+
|
23 |
+
#----------------------------------------------------------------------------
|
24 |
+
|
25 |
+
def parse_range(s: Union[str, List]) -> List[int]:
|
26 |
+
'''Parse a comma separated list of numbers or ranges and return a list of ints.
|
27 |
+
|
28 |
+
Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
|
29 |
+
'''
|
30 |
+
if isinstance(s, list): return s
|
31 |
+
ranges = []
|
32 |
+
range_re = re.compile(r'^(\d+)-(\d+)$')
|
33 |
+
for p in s.split(','):
|
34 |
+
m = range_re.match(p)
|
35 |
+
if m:
|
36 |
+
ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
|
37 |
+
else:
|
38 |
+
ranges.append(int(p))
|
39 |
+
return ranges
|
40 |
+
|
41 |
+
#----------------------------------------------------------------------------
|
42 |
+
|
43 |
+
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
|
44 |
+
'''Parse a floating point 2-vector of syntax 'a,b'.
|
45 |
+
|
46 |
+
Example:
|
47 |
+
'0,1' returns (0,1)
|
48 |
+
'''
|
49 |
+
if isinstance(s, tuple): return s
|
50 |
+
parts = s.split(',')
|
51 |
+
if len(parts) == 2:
|
52 |
+
return (float(parts[0]), float(parts[1]))
|
53 |
+
raise ValueError(f'cannot parse 2-vector {s}')
|
54 |
+
|
55 |
+
#----------------------------------------------------------------------------
|
56 |
+
|
57 |
+
def make_transform(translate: Tuple[float,float], angle: float):
|
58 |
+
m = np.eye(3)
|
59 |
+
s = np.sin(angle/360.0*np.pi*2)
|
60 |
+
c = np.cos(angle/360.0*np.pi*2)
|
61 |
+
m[0][0] = c
|
62 |
+
m[0][1] = s
|
63 |
+
m[0][2] = translate[0]
|
64 |
+
m[1][0] = -s
|
65 |
+
m[1][1] = c
|
66 |
+
m[1][2] = translate[1]
|
67 |
+
return m
|
68 |
+
|
69 |
+
#----------------------------------------------------------------------------
|
70 |
+
|
71 |
+
@click.command()
|
72 |
+
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
|
73 |
+
@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
|
74 |
+
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
|
75 |
+
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
|
76 |
+
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
|
77 |
+
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
|
78 |
+
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
|
79 |
+
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
|
80 |
+
def generate_images(
|
81 |
+
network_pkl: str,
|
82 |
+
seeds: List[int],
|
83 |
+
truncation_psi: float,
|
84 |
+
noise_mode: str,
|
85 |
+
outdir: str,
|
86 |
+
translate: Tuple[float,float],
|
87 |
+
rotate: float,
|
88 |
+
class_idx: Optional[int]
|
89 |
+
):
|
90 |
+
"""Generate images using pretrained network pickle.
|
91 |
+
|
92 |
+
Examples:
|
93 |
+
|
94 |
+
\b
|
95 |
+
# Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
|
96 |
+
python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
|
97 |
+
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
|
98 |
+
|
99 |
+
\b
|
100 |
+
# Generate uncurated images with truncation using the MetFaces-U dataset
|
101 |
+
python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
|
102 |
+
--network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
|
103 |
+
"""
|
104 |
+
|
105 |
+
print('Loading networks from "%s"...' % network_pkl)
|
106 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
107 |
+
dtype = torch.float32 if device.type == 'mps' else torch.float64
|
108 |
+
with dnnlib.util.open_url(network_pkl) as f:
|
109 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
|
110 |
+
# import pickle
|
111 |
+
# G = legacy.load_network_pkl(f)
|
112 |
+
# output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
|
113 |
+
# pickle.dump(G, output)
|
114 |
+
|
115 |
+
os.makedirs(outdir, exist_ok=True)
|
116 |
+
|
117 |
+
# Labels.
|
118 |
+
label = torch.zeros([1, G.c_dim], device=device)
|
119 |
+
if G.c_dim != 0:
|
120 |
+
if class_idx is None:
|
121 |
+
raise click.ClickException('Must specify class label with --class when using a conditional network')
|
122 |
+
label[:, class_idx] = 1
|
123 |
+
else:
|
124 |
+
if class_idx is not None:
|
125 |
+
print ('warn: --class=lbl ignored when running on an unconditional network')
|
126 |
+
|
127 |
+
# Generate images.
|
128 |
+
for seed_idx, seed in enumerate(seeds):
|
129 |
+
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
|
130 |
+
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
|
131 |
+
|
132 |
+
# Construct an inverse rotation/translation matrix and pass to the generator. The
|
133 |
+
# generator expects this matrix as an inverse to avoid potentially failing numerical
|
134 |
+
# operations in the network.
|
135 |
+
if hasattr(G.synthesis, 'input'):
|
136 |
+
m = make_transform(translate, rotate)
|
137 |
+
m = np.linalg.inv(m)
|
138 |
+
G.synthesis.input.transform.copy_(torch.from_numpy(m))
|
139 |
+
|
140 |
+
img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
|
141 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
142 |
+
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
|
143 |
+
|
144 |
+
|
145 |
+
#----------------------------------------------------------------------------
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
generate_images() # pylint: disable=no-value-for-parameter
|
149 |
+
|
150 |
+
#----------------------------------------------------------------------------
|
gradio_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
|
2 |
+
get_latest_points_pair, get_valid_mask,
|
3 |
+
on_change_single_global_state)
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
'draw_mask_on_image', 'draw_points_on_image',
|
7 |
+
'on_change_single_global_state', 'get_latest_points_pair',
|
8 |
+
'get_valid_mask', 'ImageMask'
|
9 |
+
]
|
gradio_utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (374 Bytes). View file
|
|
gradio_utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.67 kB). View file
|
|
gradio_utils/utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
|
5 |
+
|
6 |
+
class ImageMask(gr.components.Image):
|
7 |
+
"""
|
8 |
+
Sets: source="canvas", tool="sketch"
|
9 |
+
"""
|
10 |
+
|
11 |
+
is_template = True
|
12 |
+
|
13 |
+
def __init__(self, **kwargs):
|
14 |
+
super().__init__(source="upload",
|
15 |
+
tool="sketch",
|
16 |
+
interactive=False,
|
17 |
+
**kwargs)
|
18 |
+
|
19 |
+
def preprocess(self, x):
|
20 |
+
if x is None:
|
21 |
+
return x
|
22 |
+
if self.tool == "sketch" and self.source in ["upload", "webcam"
|
23 |
+
] and type(x) != dict:
|
24 |
+
decode_image = gr.processing_utils.decode_base64_to_image(x)
|
25 |
+
width, height = decode_image.size
|
26 |
+
mask = np.ones((height, width, 4), dtype=np.uint8)
|
27 |
+
mask[..., -1] = 255
|
28 |
+
mask = self.postprocess(mask)
|
29 |
+
x = {'image': x, 'mask': mask}
|
30 |
+
return super().preprocess(x)
|
31 |
+
|
32 |
+
|
33 |
+
def get_valid_mask(mask: np.ndarray):
|
34 |
+
"""Convert mask from gr.Image(0 to 255, RGBA) to binary mask.
|
35 |
+
"""
|
36 |
+
if mask.ndim == 3:
|
37 |
+
mask_pil = Image.fromarray(mask).convert('L')
|
38 |
+
mask = np.array(mask_pil)
|
39 |
+
if mask.max() == 255:
|
40 |
+
mask = mask / 255
|
41 |
+
return mask
|
42 |
+
|
43 |
+
|
44 |
+
def draw_points_on_image(image,
|
45 |
+
points,
|
46 |
+
curr_point=None,
|
47 |
+
highlight_all=True,
|
48 |
+
radius_scale=0.01):
|
49 |
+
overlay_rgba = Image.new("RGBA", image.size, 0)
|
50 |
+
overlay_draw = ImageDraw.Draw(overlay_rgba)
|
51 |
+
for point_key, point in points.items():
|
52 |
+
if ((curr_point is not None and curr_point == point_key)
|
53 |
+
or highlight_all):
|
54 |
+
p_color = (255, 0, 0)
|
55 |
+
t_color = (0, 0, 255)
|
56 |
+
|
57 |
+
else:
|
58 |
+
p_color = (255, 0, 0, 35)
|
59 |
+
t_color = (0, 0, 255, 35)
|
60 |
+
|
61 |
+
rad_draw = int(image.size[0] * radius_scale)
|
62 |
+
|
63 |
+
p_start = point.get("start_temp", point["start"])
|
64 |
+
p_target = point["target"]
|
65 |
+
|
66 |
+
if p_start is not None and p_target is not None:
|
67 |
+
p_draw = int(p_start[0]), int(p_start[1])
|
68 |
+
t_draw = int(p_target[0]), int(p_target[1])
|
69 |
+
|
70 |
+
overlay_draw.line(
|
71 |
+
(p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
|
72 |
+
fill=(255, 255, 0),
|
73 |
+
width=2,
|
74 |
+
)
|
75 |
+
|
76 |
+
if p_start is not None:
|
77 |
+
p_draw = int(p_start[0]), int(p_start[1])
|
78 |
+
overlay_draw.ellipse(
|
79 |
+
(
|
80 |
+
p_draw[0] - rad_draw,
|
81 |
+
p_draw[1] - rad_draw,
|
82 |
+
p_draw[0] + rad_draw,
|
83 |
+
p_draw[1] + rad_draw,
|
84 |
+
),
|
85 |
+
fill=p_color,
|
86 |
+
)
|
87 |
+
|
88 |
+
if curr_point is not None and curr_point == point_key:
|
89 |
+
# overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0))
|
90 |
+
overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))
|
91 |
+
|
92 |
+
if p_target is not None:
|
93 |
+
t_draw = int(p_target[0]), int(p_target[1])
|
94 |
+
overlay_draw.ellipse(
|
95 |
+
(
|
96 |
+
t_draw[0] - rad_draw,
|
97 |
+
t_draw[1] - rad_draw,
|
98 |
+
t_draw[0] + rad_draw,
|
99 |
+
t_draw[1] + rad_draw,
|
100 |
+
),
|
101 |
+
fill=t_color,
|
102 |
+
)
|
103 |
+
|
104 |
+
if curr_point is not None and curr_point == point_key:
|
105 |
+
# overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0))
|
106 |
+
overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))
|
107 |
+
|
108 |
+
return Image.alpha_composite(image.convert("RGBA"),
|
109 |
+
overlay_rgba).convert("RGB")
|
110 |
+
|
111 |
+
|
112 |
+
def draw_mask_on_image(image, mask):
|
113 |
+
im_mask = np.uint8(mask * 255)
|
114 |
+
im_mask_rgba = np.concatenate(
|
115 |
+
(
|
116 |
+
np.tile(im_mask[..., None], [1, 1, 3]),
|
117 |
+
45 * np.ones(
|
118 |
+
(im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8),
|
119 |
+
),
|
120 |
+
axis=-1,
|
121 |
+
)
|
122 |
+
im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
|
123 |
+
|
124 |
+
return Image.alpha_composite(image.convert("RGBA"),
|
125 |
+
im_mask_rgba).convert("RGB")
|
126 |
+
|
127 |
+
|
128 |
+
def on_change_single_global_state(keys,
|
129 |
+
value,
|
130 |
+
global_state,
|
131 |
+
map_transform=None):
|
132 |
+
if map_transform is not None:
|
133 |
+
value = map_transform(value)
|
134 |
+
|
135 |
+
curr_state = global_state
|
136 |
+
if isinstance(keys, str):
|
137 |
+
last_key = keys
|
138 |
+
|
139 |
+
else:
|
140 |
+
for k in keys[:-1]:
|
141 |
+
curr_state = curr_state[k]
|
142 |
+
|
143 |
+
last_key = keys[-1]
|
144 |
+
|
145 |
+
curr_state[last_key] = value
|
146 |
+
return global_state
|
147 |
+
|
148 |
+
|
149 |
+
def get_latest_points_pair(points_dict):
|
150 |
+
if not points_dict:
|
151 |
+
return None
|
152 |
+
point_idx = list(points_dict.keys())
|
153 |
+
latest_point_idx = max(point_idx)
|
154 |
+
return latest_point_idx
|
gui_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
gui_utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (123 Bytes). View file
|
|
gui_utils/__pycache__/gl_utils.cpython-39.pyc
ADDED
Binary file (12.7 kB). View file
|
|
gui_utils/__pycache__/glfw_window.cpython-39.pyc
ADDED
Binary file (7.72 kB). View file
|
|
gui_utils/__pycache__/imgui_utils.cpython-39.pyc
ADDED
Binary file (5.78 kB). View file
|
|
gui_utils/__pycache__/imgui_window.cpython-39.pyc
ADDED
Binary file (3.93 kB). View file
|
|
gui_utils/__pycache__/text_utils.cpython-39.pyc
ADDED
Binary file (5.09 kB). View file
|
|
gui_utils/gl_utils.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import functools
|
12 |
+
import contextlib
|
13 |
+
import numpy as np
|
14 |
+
import OpenGL.GL as gl
|
15 |
+
import OpenGL.GL.ARB.texture_float
|
16 |
+
import dnnlib
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
|
20 |
+
def init_egl():
|
21 |
+
assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
|
22 |
+
import OpenGL.EGL as egl
|
23 |
+
import ctypes
|
24 |
+
|
25 |
+
# Initialize EGL.
|
26 |
+
display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
|
27 |
+
assert display != egl.EGL_NO_DISPLAY
|
28 |
+
major = ctypes.c_int32()
|
29 |
+
minor = ctypes.c_int32()
|
30 |
+
ok = egl.eglInitialize(display, major, minor)
|
31 |
+
assert ok
|
32 |
+
assert major.value * 10 + minor.value >= 14
|
33 |
+
|
34 |
+
# Choose config.
|
35 |
+
config_attribs = [
|
36 |
+
egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
|
37 |
+
egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
|
38 |
+
egl.EGL_NONE
|
39 |
+
]
|
40 |
+
configs = (ctypes.c_int32 * 1)()
|
41 |
+
num_configs = ctypes.c_int32()
|
42 |
+
ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
|
43 |
+
assert ok
|
44 |
+
assert num_configs.value == 1
|
45 |
+
config = configs[0]
|
46 |
+
|
47 |
+
# Create dummy pbuffer surface.
|
48 |
+
surface_attribs = [
|
49 |
+
egl.EGL_WIDTH, 1,
|
50 |
+
egl.EGL_HEIGHT, 1,
|
51 |
+
egl.EGL_NONE
|
52 |
+
]
|
53 |
+
surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
|
54 |
+
assert surface != egl.EGL_NO_SURFACE
|
55 |
+
|
56 |
+
# Setup GL context.
|
57 |
+
ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
|
58 |
+
assert ok
|
59 |
+
context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
|
60 |
+
assert context != egl.EGL_NO_CONTEXT
|
61 |
+
ok = egl.eglMakeCurrent(display, surface, surface, context)
|
62 |
+
assert ok
|
63 |
+
|
64 |
+
#----------------------------------------------------------------------------
|
65 |
+
|
66 |
+
_texture_formats = {
|
67 |
+
('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
|
68 |
+
('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
|
69 |
+
('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
|
70 |
+
('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
|
71 |
+
('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
|
72 |
+
('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
|
73 |
+
('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
|
74 |
+
('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
|
75 |
+
}
|
76 |
+
|
77 |
+
def get_texture_format(dtype, channels):
|
78 |
+
return _texture_formats[(np.dtype(dtype).name, int(channels))]
|
79 |
+
|
80 |
+
#----------------------------------------------------------------------------
|
81 |
+
|
82 |
+
def prepare_texture_data(image):
|
83 |
+
image = np.asarray(image)
|
84 |
+
if image.ndim == 2:
|
85 |
+
image = image[:, :, np.newaxis]
|
86 |
+
if image.dtype.name == 'float64':
|
87 |
+
image = image.astype('float32')
|
88 |
+
return image
|
89 |
+
|
90 |
+
#----------------------------------------------------------------------------
|
91 |
+
|
92 |
+
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
|
93 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
94 |
+
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
|
95 |
+
align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
|
96 |
+
image = prepare_texture_data(image)
|
97 |
+
height, width, channels = image.shape
|
98 |
+
size = zoom * [width, height]
|
99 |
+
pos = pos - size * align
|
100 |
+
if rint:
|
101 |
+
pos = np.rint(pos)
|
102 |
+
fmt = get_texture_format(image.dtype, channels)
|
103 |
+
|
104 |
+
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
|
105 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
106 |
+
gl.glRasterPos2f(pos[0], pos[1])
|
107 |
+
gl.glPixelZoom(zoom[0], -zoom[1])
|
108 |
+
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
|
109 |
+
gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
|
110 |
+
gl.glPopClientAttrib()
|
111 |
+
gl.glPopAttrib()
|
112 |
+
|
113 |
+
#----------------------------------------------------------------------------
|
114 |
+
|
115 |
+
def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
|
116 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
117 |
+
dtype = np.dtype(dtype)
|
118 |
+
fmt = get_texture_format(dtype, channels)
|
119 |
+
image = np.empty([height, width, channels], dtype=dtype)
|
120 |
+
|
121 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
122 |
+
gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
|
123 |
+
gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
|
124 |
+
gl.glPopClientAttrib()
|
125 |
+
return np.flipud(image)
|
126 |
+
|
127 |
+
#----------------------------------------------------------------------------
|
128 |
+
|
129 |
+
class Texture:
|
130 |
+
def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
|
131 |
+
self.gl_id = None
|
132 |
+
self.bilinear = bilinear
|
133 |
+
self.mipmap = mipmap
|
134 |
+
|
135 |
+
# Determine size and dtype.
|
136 |
+
if image is not None:
|
137 |
+
image = prepare_texture_data(image)
|
138 |
+
self.height, self.width, self.channels = image.shape
|
139 |
+
self.dtype = image.dtype
|
140 |
+
else:
|
141 |
+
assert width is not None and height is not None
|
142 |
+
self.width = width
|
143 |
+
self.height = height
|
144 |
+
self.channels = channels if channels is not None else 3
|
145 |
+
self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
|
146 |
+
|
147 |
+
# Validate size and dtype.
|
148 |
+
assert isinstance(self.width, int) and self.width >= 0
|
149 |
+
assert isinstance(self.height, int) and self.height >= 0
|
150 |
+
assert isinstance(self.channels, int) and self.channels >= 1
|
151 |
+
assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
|
152 |
+
|
153 |
+
# Create texture object.
|
154 |
+
self.gl_id = gl.glGenTextures(1)
|
155 |
+
with self.bind():
|
156 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
157 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
158 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
|
159 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
|
160 |
+
self.update(image)
|
161 |
+
|
162 |
+
def delete(self):
|
163 |
+
if self.gl_id is not None:
|
164 |
+
gl.glDeleteTextures([self.gl_id])
|
165 |
+
self.gl_id = None
|
166 |
+
|
167 |
+
def __del__(self):
|
168 |
+
try:
|
169 |
+
self.delete()
|
170 |
+
except:
|
171 |
+
pass
|
172 |
+
|
173 |
+
@contextlib.contextmanager
|
174 |
+
def bind(self):
|
175 |
+
prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
|
176 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
|
177 |
+
yield
|
178 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
|
179 |
+
|
180 |
+
def update(self, image):
|
181 |
+
if image is not None:
|
182 |
+
image = prepare_texture_data(image)
|
183 |
+
assert self.is_compatible(image=image)
|
184 |
+
with self.bind():
|
185 |
+
fmt = get_texture_format(self.dtype, self.channels)
|
186 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
187 |
+
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
|
188 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
|
189 |
+
if self.mipmap:
|
190 |
+
gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
|
191 |
+
gl.glPopClientAttrib()
|
192 |
+
|
193 |
+
def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
|
194 |
+
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
|
195 |
+
size = zoom * [self.width, self.height]
|
196 |
+
with self.bind():
|
197 |
+
gl.glPushAttrib(gl.GL_ENABLE_BIT)
|
198 |
+
gl.glEnable(gl.GL_TEXTURE_2D)
|
199 |
+
draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
|
200 |
+
gl.glPopAttrib()
|
201 |
+
|
202 |
+
def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
|
203 |
+
if image is not None:
|
204 |
+
if image.ndim != 3:
|
205 |
+
return False
|
206 |
+
ih, iw, ic = image.shape
|
207 |
+
if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
|
208 |
+
return False
|
209 |
+
if width is not None and self.width != width:
|
210 |
+
return False
|
211 |
+
if height is not None and self.height != height:
|
212 |
+
return False
|
213 |
+
if channels is not None and self.channels != channels:
|
214 |
+
return False
|
215 |
+
if dtype is not None and self.dtype != dtype:
|
216 |
+
return False
|
217 |
+
return True
|
218 |
+
|
219 |
+
#----------------------------------------------------------------------------
|
220 |
+
|
221 |
+
class Framebuffer:
|
222 |
+
def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
|
223 |
+
self.texture = texture
|
224 |
+
self.gl_id = None
|
225 |
+
self.gl_color = None
|
226 |
+
self.gl_depth_stencil = None
|
227 |
+
self.msaa = msaa
|
228 |
+
|
229 |
+
# Determine size and dtype.
|
230 |
+
if texture is not None:
|
231 |
+
assert isinstance(self.texture, Texture)
|
232 |
+
self.width = texture.width
|
233 |
+
self.height = texture.height
|
234 |
+
self.channels = texture.channels
|
235 |
+
self.dtype = texture.dtype
|
236 |
+
else:
|
237 |
+
assert width is not None and height is not None
|
238 |
+
self.width = width
|
239 |
+
self.height = height
|
240 |
+
self.channels = channels if channels is not None else 4
|
241 |
+
self.dtype = np.dtype(dtype) if dtype is not None else np.float32
|
242 |
+
|
243 |
+
# Validate size and dtype.
|
244 |
+
assert isinstance(self.width, int) and self.width >= 0
|
245 |
+
assert isinstance(self.height, int) and self.height >= 0
|
246 |
+
assert isinstance(self.channels, int) and self.channels >= 1
|
247 |
+
assert width is None or width == self.width
|
248 |
+
assert height is None or height == self.height
|
249 |
+
assert channels is None or channels == self.channels
|
250 |
+
assert dtype is None or dtype == self.dtype
|
251 |
+
|
252 |
+
# Create framebuffer object.
|
253 |
+
self.gl_id = gl.glGenFramebuffers(1)
|
254 |
+
with self.bind():
|
255 |
+
|
256 |
+
# Setup color buffer.
|
257 |
+
if self.texture is not None:
|
258 |
+
assert self.msaa == 0
|
259 |
+
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
|
260 |
+
else:
|
261 |
+
fmt = get_texture_format(self.dtype, self.channels)
|
262 |
+
self.gl_color = gl.glGenRenderbuffers(1)
|
263 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
|
264 |
+
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
|
265 |
+
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
|
266 |
+
|
267 |
+
# Setup depth/stencil buffer.
|
268 |
+
self.gl_depth_stencil = gl.glGenRenderbuffers(1)
|
269 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
|
270 |
+
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
|
271 |
+
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
|
272 |
+
|
273 |
+
def delete(self):
|
274 |
+
if self.gl_id is not None:
|
275 |
+
gl.glDeleteFramebuffers([self.gl_id])
|
276 |
+
self.gl_id = None
|
277 |
+
if self.gl_color is not None:
|
278 |
+
gl.glDeleteRenderbuffers(1, [self.gl_color])
|
279 |
+
self.gl_color = None
|
280 |
+
if self.gl_depth_stencil is not None:
|
281 |
+
gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
|
282 |
+
self.gl_depth_stencil = None
|
283 |
+
|
284 |
+
def __del__(self):
|
285 |
+
try:
|
286 |
+
self.delete()
|
287 |
+
except:
|
288 |
+
pass
|
289 |
+
|
290 |
+
@contextlib.contextmanager
|
291 |
+
def bind(self):
|
292 |
+
prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
|
293 |
+
prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
|
294 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
|
295 |
+
if self.width is not None and self.height is not None:
|
296 |
+
gl.glViewport(0, 0, self.width, self.height)
|
297 |
+
yield
|
298 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
|
299 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
|
300 |
+
|
301 |
+
def blit(self, dst=None):
|
302 |
+
assert dst is None or isinstance(dst, Framebuffer)
|
303 |
+
with self.bind():
|
304 |
+
gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo)
|
305 |
+
gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
|
306 |
+
|
307 |
+
#----------------------------------------------------------------------------
|
308 |
+
|
309 |
+
def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
|
310 |
+
assert vertices.ndim == 2 and vertices.shape[1] == 2
|
311 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
312 |
+
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
|
313 |
+
color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
|
314 |
+
alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
|
315 |
+
|
316 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
|
317 |
+
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
|
318 |
+
gl.glMatrixMode(gl.GL_MODELVIEW)
|
319 |
+
gl.glPushMatrix()
|
320 |
+
|
321 |
+
gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
|
322 |
+
gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
|
323 |
+
gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
|
324 |
+
gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
|
325 |
+
gl.glTranslate(pos[0], pos[1], 0)
|
326 |
+
gl.glScale(size[0], size[1], 1)
|
327 |
+
gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
|
328 |
+
gl.glDrawArrays(mode, 0, vertices.shape[0])
|
329 |
+
|
330 |
+
gl.glPopMatrix()
|
331 |
+
gl.glPopAttrib()
|
332 |
+
gl.glPopClientAttrib()
|
333 |
+
|
334 |
+
#----------------------------------------------------------------------------
|
335 |
+
|
336 |
+
def draw_arrow(x1, y1, x2, y2, l=10, width=1.0):
|
337 |
+
# Compute the length and angle of the arrow
|
338 |
+
dx = x2 - x1
|
339 |
+
dy = y2 - y1
|
340 |
+
length = math.sqrt(dx**2 + dy**2)
|
341 |
+
if length < l:
|
342 |
+
return
|
343 |
+
angle = math.atan2(dy, dx)
|
344 |
+
|
345 |
+
# Save the current modelview matrix
|
346 |
+
gl.glPushMatrix()
|
347 |
+
|
348 |
+
# Translate and rotate the coordinate system
|
349 |
+
gl.glTranslatef(x1, y1, 0.0)
|
350 |
+
gl.glRotatef(angle * 180.0 / math.pi, 0.0, 0.0, 1.0)
|
351 |
+
|
352 |
+
# Set the line width
|
353 |
+
gl.glLineWidth(width)
|
354 |
+
# gl.glColor3f(0.75, 0.75, 0.75)
|
355 |
+
|
356 |
+
# Begin drawing lines
|
357 |
+
gl.glBegin(gl.GL_LINES)
|
358 |
+
|
359 |
+
# Draw the shaft of the arrow
|
360 |
+
gl.glVertex2f(0.0, 0.0)
|
361 |
+
gl.glVertex2f(length, 0.0)
|
362 |
+
|
363 |
+
# Draw the head of the arrow
|
364 |
+
gl.glVertex2f(length, 0.0)
|
365 |
+
gl.glVertex2f(length - 2 * l, l)
|
366 |
+
gl.glVertex2f(length, 0.0)
|
367 |
+
gl.glVertex2f(length - 2 * l, -l)
|
368 |
+
|
369 |
+
# End drawing lines
|
370 |
+
gl.glEnd()
|
371 |
+
|
372 |
+
# Restore the modelview matrix
|
373 |
+
gl.glPopMatrix()
|
374 |
+
|
375 |
+
#----------------------------------------------------------------------------
|
376 |
+
|
377 |
+
def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
|
378 |
+
assert pos2 is None or size is None
|
379 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
380 |
+
pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
|
381 |
+
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
|
382 |
+
size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
|
383 |
+
pos = pos - size * align
|
384 |
+
if rint:
|
385 |
+
pos = np.rint(pos)
|
386 |
+
rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
|
387 |
+
rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
|
388 |
+
if np.min(rounding) == 0:
|
389 |
+
rounding *= 0
|
390 |
+
vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
|
391 |
+
draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
|
392 |
+
|
393 |
+
@functools.lru_cache(maxsize=10000)
|
394 |
+
def _setup_rect(rx, ry):
|
395 |
+
t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
|
396 |
+
s = 1 - np.sin(t); c = 1 - np.cos(t)
|
397 |
+
x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
|
398 |
+
y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
|
399 |
+
v = np.stack([x, y], axis=-1).reshape(-1, 2)
|
400 |
+
return v.astype('float32')
|
401 |
+
|
402 |
+
#----------------------------------------------------------------------------
|
403 |
+
|
404 |
+
def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
|
405 |
+
hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
|
406 |
+
vertices = _setup_circle(float(hole))
|
407 |
+
draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
|
408 |
+
|
409 |
+
@functools.lru_cache(maxsize=10000)
|
410 |
+
def _setup_circle(hole):
|
411 |
+
t = np.linspace(0, np.pi * 2, 128)
|
412 |
+
s = np.sin(t); c = np.cos(t)
|
413 |
+
v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
|
414 |
+
return v.astype('float32')
|
415 |
+
|
416 |
+
#----------------------------------------------------------------------------
|
gui_utils/glfw_window.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import time
|
10 |
+
import glfw
|
11 |
+
import OpenGL.GL as gl
|
12 |
+
from . import gl_utils
|
13 |
+
|
14 |
+
#----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
class GlfwWindow: # pylint: disable=too-many-public-methods
|
17 |
+
def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
|
18 |
+
self._glfw_window = None
|
19 |
+
self._drawing_frame = False
|
20 |
+
self._frame_start_time = None
|
21 |
+
self._frame_delta = 0
|
22 |
+
self._fps_limit = None
|
23 |
+
self._vsync = None
|
24 |
+
self._skip_frames = 0
|
25 |
+
self._deferred_show = deferred_show
|
26 |
+
self._close_on_esc = close_on_esc
|
27 |
+
self._esc_pressed = False
|
28 |
+
self._drag_and_drop_paths = None
|
29 |
+
self._capture_next_frame = False
|
30 |
+
self._captured_frame = None
|
31 |
+
|
32 |
+
# Create window.
|
33 |
+
glfw.init()
|
34 |
+
glfw.window_hint(glfw.VISIBLE, False)
|
35 |
+
self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
|
36 |
+
self._attach_glfw_callbacks()
|
37 |
+
self.make_context_current()
|
38 |
+
|
39 |
+
# Adjust window.
|
40 |
+
self.set_vsync(False)
|
41 |
+
self.set_window_size(window_width, window_height)
|
42 |
+
if not self._deferred_show:
|
43 |
+
glfw.show_window(self._glfw_window)
|
44 |
+
|
45 |
+
def close(self):
|
46 |
+
if self._drawing_frame:
|
47 |
+
self.end_frame()
|
48 |
+
if self._glfw_window is not None:
|
49 |
+
glfw.destroy_window(self._glfw_window)
|
50 |
+
self._glfw_window = None
|
51 |
+
#glfw.terminate() # Commented out to play it nice with other glfw clients.
|
52 |
+
|
53 |
+
def __del__(self):
|
54 |
+
try:
|
55 |
+
self.close()
|
56 |
+
except:
|
57 |
+
pass
|
58 |
+
|
59 |
+
@property
|
60 |
+
def window_width(self):
|
61 |
+
return self.content_width
|
62 |
+
|
63 |
+
@property
|
64 |
+
def window_height(self):
|
65 |
+
return self.content_height + self.title_bar_height
|
66 |
+
|
67 |
+
@property
|
68 |
+
def content_width(self):
|
69 |
+
width, _height = glfw.get_window_size(self._glfw_window)
|
70 |
+
return width
|
71 |
+
|
72 |
+
@property
|
73 |
+
def content_height(self):
|
74 |
+
_width, height = glfw.get_window_size(self._glfw_window)
|
75 |
+
return height
|
76 |
+
|
77 |
+
@property
|
78 |
+
def title_bar_height(self):
|
79 |
+
_left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
|
80 |
+
return top
|
81 |
+
|
82 |
+
@property
|
83 |
+
def monitor_width(self):
|
84 |
+
_, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
|
85 |
+
return width
|
86 |
+
|
87 |
+
@property
|
88 |
+
def monitor_height(self):
|
89 |
+
_, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
|
90 |
+
return height
|
91 |
+
|
92 |
+
@property
|
93 |
+
def frame_delta(self):
|
94 |
+
return self._frame_delta
|
95 |
+
|
96 |
+
def set_title(self, title):
|
97 |
+
glfw.set_window_title(self._glfw_window, title)
|
98 |
+
|
99 |
+
def set_window_size(self, width, height):
|
100 |
+
width = min(width, self.monitor_width)
|
101 |
+
height = min(height, self.monitor_height)
|
102 |
+
glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
|
103 |
+
if width == self.monitor_width and height == self.monitor_height:
|
104 |
+
self.maximize()
|
105 |
+
|
106 |
+
def set_content_size(self, width, height):
|
107 |
+
self.set_window_size(width, height + self.title_bar_height)
|
108 |
+
|
109 |
+
def maximize(self):
|
110 |
+
glfw.maximize_window(self._glfw_window)
|
111 |
+
|
112 |
+
def set_position(self, x, y):
|
113 |
+
glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
|
114 |
+
|
115 |
+
def center(self):
|
116 |
+
self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
|
117 |
+
|
118 |
+
def set_vsync(self, vsync):
|
119 |
+
vsync = bool(vsync)
|
120 |
+
if vsync != self._vsync:
|
121 |
+
glfw.swap_interval(1 if vsync else 0)
|
122 |
+
self._vsync = vsync
|
123 |
+
|
124 |
+
def set_fps_limit(self, fps_limit):
|
125 |
+
self._fps_limit = int(fps_limit)
|
126 |
+
|
127 |
+
def should_close(self):
|
128 |
+
return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
|
129 |
+
|
130 |
+
def skip_frame(self):
|
131 |
+
self.skip_frames(1)
|
132 |
+
|
133 |
+
def skip_frames(self, num): # Do not update window for the next N frames.
|
134 |
+
self._skip_frames = max(self._skip_frames, int(num))
|
135 |
+
|
136 |
+
def is_skipping_frames(self):
|
137 |
+
return self._skip_frames > 0
|
138 |
+
|
139 |
+
def capture_next_frame(self):
|
140 |
+
self._capture_next_frame = True
|
141 |
+
|
142 |
+
def pop_captured_frame(self):
|
143 |
+
frame = self._captured_frame
|
144 |
+
self._captured_frame = None
|
145 |
+
return frame
|
146 |
+
|
147 |
+
def pop_drag_and_drop_paths(self):
|
148 |
+
paths = self._drag_and_drop_paths
|
149 |
+
self._drag_and_drop_paths = None
|
150 |
+
return paths
|
151 |
+
|
152 |
+
def draw_frame(self): # To be overridden by subclass.
|
153 |
+
self.begin_frame()
|
154 |
+
# Rendering code goes here.
|
155 |
+
self.end_frame()
|
156 |
+
|
157 |
+
def make_context_current(self):
|
158 |
+
if self._glfw_window is not None:
|
159 |
+
glfw.make_context_current(self._glfw_window)
|
160 |
+
|
161 |
+
def begin_frame(self):
|
162 |
+
# End previous frame.
|
163 |
+
if self._drawing_frame:
|
164 |
+
self.end_frame()
|
165 |
+
|
166 |
+
# Apply FPS limit.
|
167 |
+
if self._frame_start_time is not None and self._fps_limit is not None:
|
168 |
+
delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
|
169 |
+
if delay > 0:
|
170 |
+
time.sleep(delay)
|
171 |
+
cur_time = time.perf_counter()
|
172 |
+
if self._frame_start_time is not None:
|
173 |
+
self._frame_delta = cur_time - self._frame_start_time
|
174 |
+
self._frame_start_time = cur_time
|
175 |
+
|
176 |
+
# Process events.
|
177 |
+
glfw.poll_events()
|
178 |
+
|
179 |
+
# Begin frame.
|
180 |
+
self._drawing_frame = True
|
181 |
+
self.make_context_current()
|
182 |
+
|
183 |
+
# Initialize GL state.
|
184 |
+
gl.glViewport(0, 0, self.content_width, self.content_height)
|
185 |
+
gl.glMatrixMode(gl.GL_PROJECTION)
|
186 |
+
gl.glLoadIdentity()
|
187 |
+
gl.glTranslate(-1, 1, 0)
|
188 |
+
gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
|
189 |
+
gl.glMatrixMode(gl.GL_MODELVIEW)
|
190 |
+
gl.glLoadIdentity()
|
191 |
+
gl.glEnable(gl.GL_BLEND)
|
192 |
+
gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
|
193 |
+
|
194 |
+
# Clear.
|
195 |
+
gl.glClearColor(0, 0, 0, 1)
|
196 |
+
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
|
197 |
+
|
198 |
+
def end_frame(self):
|
199 |
+
assert self._drawing_frame
|
200 |
+
self._drawing_frame = False
|
201 |
+
|
202 |
+
# Skip frames if requested.
|
203 |
+
if self._skip_frames > 0:
|
204 |
+
self._skip_frames -= 1
|
205 |
+
return
|
206 |
+
|
207 |
+
# Capture frame if requested.
|
208 |
+
if self._capture_next_frame:
|
209 |
+
self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
|
210 |
+
self._capture_next_frame = False
|
211 |
+
|
212 |
+
# Update window.
|
213 |
+
if self._deferred_show:
|
214 |
+
glfw.show_window(self._glfw_window)
|
215 |
+
self._deferred_show = False
|
216 |
+
glfw.swap_buffers(self._glfw_window)
|
217 |
+
|
218 |
+
def _attach_glfw_callbacks(self):
|
219 |
+
glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
|
220 |
+
glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
|
221 |
+
|
222 |
+
def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
|
223 |
+
if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
|
224 |
+
self._esc_pressed = True
|
225 |
+
|
226 |
+
def _glfw_drop_callback(self, _window, paths):
|
227 |
+
self._drag_and_drop_paths = paths
|
228 |
+
|
229 |
+
#----------------------------------------------------------------------------
|
gui_utils/imgui_utils.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import contextlib
|
10 |
+
import imgui
|
11 |
+
|
12 |
+
#----------------------------------------------------------------------------
|
13 |
+
|
14 |
+
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
|
15 |
+
s = imgui.get_style()
|
16 |
+
s.window_padding = [spacing, spacing]
|
17 |
+
s.item_spacing = [spacing, spacing]
|
18 |
+
s.item_inner_spacing = [spacing, spacing]
|
19 |
+
s.columns_min_spacing = spacing
|
20 |
+
s.indent_spacing = indent
|
21 |
+
s.scrollbar_size = scrollbar
|
22 |
+
s.frame_padding = [4, 3]
|
23 |
+
s.window_border_size = 1
|
24 |
+
s.child_border_size = 1
|
25 |
+
s.popup_border_size = 1
|
26 |
+
s.frame_border_size = 1
|
27 |
+
s.window_rounding = 0
|
28 |
+
s.child_rounding = 0
|
29 |
+
s.popup_rounding = 3
|
30 |
+
s.frame_rounding = 3
|
31 |
+
s.scrollbar_rounding = 3
|
32 |
+
s.grab_rounding = 3
|
33 |
+
|
34 |
+
getattr(imgui, f'style_colors_{color_scheme}')(s)
|
35 |
+
c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
|
36 |
+
c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
|
37 |
+
s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
|
38 |
+
|
39 |
+
#----------------------------------------------------------------------------
|
40 |
+
|
41 |
+
@contextlib.contextmanager
|
42 |
+
def grayed_out(cond=True):
|
43 |
+
if cond:
|
44 |
+
s = imgui.get_style()
|
45 |
+
text = s.colors[imgui.COLOR_TEXT_DISABLED]
|
46 |
+
grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
|
47 |
+
back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
|
48 |
+
imgui.push_style_color(imgui.COLOR_TEXT, *text)
|
49 |
+
imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
|
50 |
+
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
|
51 |
+
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
|
52 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
|
53 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
|
54 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
|
55 |
+
imgui.push_style_color(imgui.COLOR_BUTTON, *back)
|
56 |
+
imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
|
57 |
+
imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
|
58 |
+
imgui.push_style_color(imgui.COLOR_HEADER, *back)
|
59 |
+
imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
|
60 |
+
imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
|
61 |
+
imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
|
62 |
+
yield
|
63 |
+
imgui.pop_style_color(14)
|
64 |
+
else:
|
65 |
+
yield
|
66 |
+
|
67 |
+
#----------------------------------------------------------------------------
|
68 |
+
|
69 |
+
@contextlib.contextmanager
|
70 |
+
def item_width(width=None):
|
71 |
+
if width is not None:
|
72 |
+
imgui.push_item_width(width)
|
73 |
+
yield
|
74 |
+
imgui.pop_item_width()
|
75 |
+
else:
|
76 |
+
yield
|
77 |
+
|
78 |
+
#----------------------------------------------------------------------------
|
79 |
+
|
80 |
+
def scoped_by_object_id(method):
|
81 |
+
def decorator(self, *args, **kwargs):
|
82 |
+
imgui.push_id(str(id(self)))
|
83 |
+
res = method(self, *args, **kwargs)
|
84 |
+
imgui.pop_id()
|
85 |
+
return res
|
86 |
+
return decorator
|
87 |
+
|
88 |
+
#----------------------------------------------------------------------------
|
89 |
+
|
90 |
+
def button(label, width=0, enabled=True):
|
91 |
+
with grayed_out(not enabled):
|
92 |
+
clicked = imgui.button(label, width=width)
|
93 |
+
clicked = clicked and enabled
|
94 |
+
return clicked
|
95 |
+
|
96 |
+
#----------------------------------------------------------------------------
|
97 |
+
|
98 |
+
def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
|
99 |
+
expanded = False
|
100 |
+
if show:
|
101 |
+
if default:
|
102 |
+
flags |= imgui.TREE_NODE_DEFAULT_OPEN
|
103 |
+
if not enabled:
|
104 |
+
flags |= imgui.TREE_NODE_LEAF
|
105 |
+
with grayed_out(not enabled):
|
106 |
+
expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
|
107 |
+
expanded = expanded and enabled
|
108 |
+
return expanded, visible
|
109 |
+
|
110 |
+
#----------------------------------------------------------------------------
|
111 |
+
|
112 |
+
def popup_button(label, width=0, enabled=True):
|
113 |
+
if button(label, width, enabled):
|
114 |
+
imgui.open_popup(label)
|
115 |
+
opened = imgui.begin_popup(label)
|
116 |
+
return opened
|
117 |
+
|
118 |
+
#----------------------------------------------------------------------------
|
119 |
+
|
120 |
+
def input_text(label, value, buffer_length, flags, width=None, help_text=''):
|
121 |
+
old_value = value
|
122 |
+
color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
|
123 |
+
if value == '':
|
124 |
+
color[-1] *= 0.5
|
125 |
+
with item_width(width):
|
126 |
+
imgui.push_style_color(imgui.COLOR_TEXT, *color)
|
127 |
+
value = value if value != '' else help_text
|
128 |
+
changed, value = imgui.input_text(label, value, buffer_length, flags)
|
129 |
+
value = value if value != help_text else ''
|
130 |
+
imgui.pop_style_color(1)
|
131 |
+
if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
|
132 |
+
changed = (value != old_value)
|
133 |
+
return changed, value
|
134 |
+
|
135 |
+
#----------------------------------------------------------------------------
|
136 |
+
|
137 |
+
def drag_previous_control(enabled=True):
|
138 |
+
dragging = False
|
139 |
+
dx = 0
|
140 |
+
dy = 0
|
141 |
+
if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
|
142 |
+
if enabled:
|
143 |
+
dragging = True
|
144 |
+
dx, dy = imgui.get_mouse_drag_delta()
|
145 |
+
imgui.reset_mouse_drag_delta()
|
146 |
+
imgui.end_drag_drop_source()
|
147 |
+
return dragging, dx, dy
|
148 |
+
|
149 |
+
#----------------------------------------------------------------------------
|
150 |
+
|
151 |
+
def drag_button(label, width=0, enabled=True):
|
152 |
+
clicked = button(label, width=width, enabled=enabled)
|
153 |
+
dragging, dx, dy = drag_previous_control(enabled=enabled)
|
154 |
+
return clicked, dragging, dx, dy
|
155 |
+
|
156 |
+
#----------------------------------------------------------------------------
|
157 |
+
|
158 |
+
def drag_hidden_window(label, x, y, width, height, enabled=True):
|
159 |
+
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
|
160 |
+
imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
|
161 |
+
imgui.set_next_window_position(x, y)
|
162 |
+
imgui.set_next_window_size(width, height)
|
163 |
+
imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
|
164 |
+
dragging, dx, dy = drag_previous_control(enabled=enabled)
|
165 |
+
imgui.end()
|
166 |
+
imgui.pop_style_color(2)
|
167 |
+
return dragging, dx, dy
|
168 |
+
|
169 |
+
#----------------------------------------------------------------------------
|
170 |
+
|
171 |
+
def click_hidden_window(label, x, y, width, height, img_w, img_h, enabled=True):
|
172 |
+
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
|
173 |
+
imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
|
174 |
+
imgui.set_next_window_position(x, y)
|
175 |
+
imgui.set_next_window_size(width, height)
|
176 |
+
imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
|
177 |
+
clicked, down = False, False
|
178 |
+
img_x, img_y = 0, 0
|
179 |
+
if imgui.is_mouse_down():
|
180 |
+
posx, posy = imgui.get_mouse_pos()
|
181 |
+
if posx >= x and posx < x + width and posy >= y and posy < y + height:
|
182 |
+
if imgui.is_mouse_clicked():
|
183 |
+
clicked = True
|
184 |
+
down = True
|
185 |
+
img_x = round((posx - x) / (width - 1) * (img_w - 1))
|
186 |
+
img_y = round((posy - y) / (height - 1) * (img_h - 1))
|
187 |
+
imgui.end()
|
188 |
+
imgui.pop_style_color(2)
|
189 |
+
return clicked, down, img_x, img_y
|
190 |
+
|
191 |
+
#----------------------------------------------------------------------------
|
gui_utils/imgui_window.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import imgui
|
11 |
+
import imgui.integrations.glfw
|
12 |
+
|
13 |
+
from . import glfw_window
|
14 |
+
from . import imgui_utils
|
15 |
+
from . import text_utils
|
16 |
+
|
17 |
+
#----------------------------------------------------------------------------
|
18 |
+
|
19 |
+
class ImguiWindow(glfw_window.GlfwWindow):
|
20 |
+
def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
|
21 |
+
if font is None:
|
22 |
+
font = text_utils.get_default_font()
|
23 |
+
font_sizes = {int(size) for size in font_sizes}
|
24 |
+
super().__init__(title=title, **glfw_kwargs)
|
25 |
+
|
26 |
+
# Init fields.
|
27 |
+
self._imgui_context = None
|
28 |
+
self._imgui_renderer = None
|
29 |
+
self._imgui_fonts = None
|
30 |
+
self._cur_font_size = max(font_sizes)
|
31 |
+
|
32 |
+
# Delete leftover imgui.ini to avoid unexpected behavior.
|
33 |
+
if os.path.isfile('imgui.ini'):
|
34 |
+
os.remove('imgui.ini')
|
35 |
+
|
36 |
+
# Init ImGui.
|
37 |
+
self._imgui_context = imgui.create_context()
|
38 |
+
self._imgui_renderer = _GlfwRenderer(self._glfw_window)
|
39 |
+
self._attach_glfw_callbacks()
|
40 |
+
imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
|
41 |
+
imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
|
42 |
+
self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
|
43 |
+
self._imgui_renderer.refresh_font_texture()
|
44 |
+
|
45 |
+
def close(self):
|
46 |
+
self.make_context_current()
|
47 |
+
self._imgui_fonts = None
|
48 |
+
if self._imgui_renderer is not None:
|
49 |
+
self._imgui_renderer.shutdown()
|
50 |
+
self._imgui_renderer = None
|
51 |
+
if self._imgui_context is not None:
|
52 |
+
#imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
|
53 |
+
self._imgui_context = None
|
54 |
+
super().close()
|
55 |
+
|
56 |
+
def _glfw_key_callback(self, *args):
|
57 |
+
super()._glfw_key_callback(*args)
|
58 |
+
self._imgui_renderer.keyboard_callback(*args)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def font_size(self):
|
62 |
+
return self._cur_font_size
|
63 |
+
|
64 |
+
@property
|
65 |
+
def spacing(self):
|
66 |
+
return round(self._cur_font_size * 0.4)
|
67 |
+
|
68 |
+
def set_font_size(self, target): # Applied on next frame.
|
69 |
+
self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
|
70 |
+
|
71 |
+
def begin_frame(self):
|
72 |
+
# Begin glfw frame.
|
73 |
+
super().begin_frame()
|
74 |
+
|
75 |
+
# Process imgui events.
|
76 |
+
self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
|
77 |
+
if self.content_width > 0 and self.content_height > 0:
|
78 |
+
self._imgui_renderer.process_inputs()
|
79 |
+
|
80 |
+
# Begin imgui frame.
|
81 |
+
imgui.new_frame()
|
82 |
+
imgui.push_font(self._imgui_fonts[self._cur_font_size])
|
83 |
+
imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
|
84 |
+
|
85 |
+
def end_frame(self):
|
86 |
+
imgui.pop_font()
|
87 |
+
imgui.render()
|
88 |
+
imgui.end_frame()
|
89 |
+
self._imgui_renderer.render(imgui.get_draw_data())
|
90 |
+
super().end_frame()
|
91 |
+
|
92 |
+
#----------------------------------------------------------------------------
|
93 |
+
# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
|
94 |
+
|
95 |
+
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
|
96 |
+
def __init__(self, *args, **kwargs):
|
97 |
+
super().__init__(*args, **kwargs)
|
98 |
+
self.mouse_wheel_multiplier = 1
|
99 |
+
|
100 |
+
def scroll_callback(self, window, x_offset, y_offset):
|
101 |
+
self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
|
102 |
+
|
103 |
+
#----------------------------------------------------------------------------
|
gui_utils/text_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import functools
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import dnnlib
|
13 |
+
import numpy as np
|
14 |
+
import PIL.Image
|
15 |
+
import PIL.ImageFont
|
16 |
+
import scipy.ndimage
|
17 |
+
|
18 |
+
from . import gl_utils
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
def get_default_font():
|
23 |
+
url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
|
24 |
+
return dnnlib.util.open_url(url, return_filename=True)
|
25 |
+
|
26 |
+
#----------------------------------------------------------------------------
|
27 |
+
|
28 |
+
@functools.lru_cache(maxsize=None)
|
29 |
+
def get_pil_font(font=None, size=32):
|
30 |
+
if font is None:
|
31 |
+
font = get_default_font()
|
32 |
+
return PIL.ImageFont.truetype(font=font, size=size)
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
def get_array(string, *, dropshadow_radius: int=None, **kwargs):
|
37 |
+
if dropshadow_radius is not None:
|
38 |
+
offset_x = int(np.ceil(dropshadow_radius*2/3))
|
39 |
+
offset_y = int(np.ceil(dropshadow_radius*2/3))
|
40 |
+
return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
|
41 |
+
else:
|
42 |
+
return _get_array_priv(string, **kwargs)
|
43 |
+
|
44 |
+
@functools.lru_cache(maxsize=10000)
|
45 |
+
def _get_array_priv(
|
46 |
+
string: str, *,
|
47 |
+
size: int = 32,
|
48 |
+
max_width: Optional[int]=None,
|
49 |
+
max_height: Optional[int]=None,
|
50 |
+
min_size=10,
|
51 |
+
shrink_coef=0.8,
|
52 |
+
dropshadow_radius: int=None,
|
53 |
+
offset_x: int=None,
|
54 |
+
offset_y: int=None,
|
55 |
+
**kwargs
|
56 |
+
):
|
57 |
+
cur_size = size
|
58 |
+
array = None
|
59 |
+
while True:
|
60 |
+
if dropshadow_radius is not None:
|
61 |
+
# separate implementation for dropshadow text rendering
|
62 |
+
array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
|
63 |
+
else:
|
64 |
+
array = _get_array_impl(string, size=cur_size, **kwargs)
|
65 |
+
height, width, _ = array.shape
|
66 |
+
if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
|
67 |
+
break
|
68 |
+
cur_size = max(int(cur_size * shrink_coef), min_size)
|
69 |
+
return array
|
70 |
+
|
71 |
+
#----------------------------------------------------------------------------
|
72 |
+
|
73 |
+
@functools.lru_cache(maxsize=10000)
|
74 |
+
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
|
75 |
+
pil_font = get_pil_font(font=font, size=size)
|
76 |
+
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
|
77 |
+
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
|
78 |
+
width = max(line.shape[1] for line in lines)
|
79 |
+
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
|
80 |
+
line_spacing = line_pad if line_pad is not None else size // 2
|
81 |
+
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
|
82 |
+
mask = np.concatenate(lines, axis=0)
|
83 |
+
alpha = mask
|
84 |
+
if outline > 0:
|
85 |
+
mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
|
86 |
+
alpha = mask.astype(np.float32) / 255
|
87 |
+
alpha = scipy.ndimage.gaussian_filter(alpha, outline)
|
88 |
+
alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
|
89 |
+
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
|
90 |
+
alpha = np.maximum(alpha, mask)
|
91 |
+
return np.stack([mask, alpha], axis=-1)
|
92 |
+
|
93 |
+
#----------------------------------------------------------------------------
|
94 |
+
|
95 |
+
@functools.lru_cache(maxsize=10000)
|
96 |
+
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
|
97 |
+
assert (offset_x > 0) and (offset_y > 0)
|
98 |
+
pil_font = get_pil_font(font=font, size=size)
|
99 |
+
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
|
100 |
+
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
|
101 |
+
width = max(line.shape[1] for line in lines)
|
102 |
+
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
|
103 |
+
line_spacing = line_pad if line_pad is not None else size // 2
|
104 |
+
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
|
105 |
+
mask = np.concatenate(lines, axis=0)
|
106 |
+
alpha = mask
|
107 |
+
|
108 |
+
mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
|
109 |
+
alpha = mask.astype(np.float32) / 255
|
110 |
+
alpha = scipy.ndimage.gaussian_filter(alpha, radius)
|
111 |
+
alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
|
112 |
+
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
|
113 |
+
alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
|
114 |
+
alpha = np.maximum(alpha, mask)
|
115 |
+
return np.stack([mask, alpha], axis=-1)
|
116 |
+
|
117 |
+
#----------------------------------------------------------------------------
|
118 |
+
|
119 |
+
@functools.lru_cache(maxsize=10000)
|
120 |
+
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
|
121 |
+
return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
|
122 |
+
|
123 |
+
#----------------------------------------------------------------------------
|
legacy.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Converting legacy network pickle into the new format."""
|
10 |
+
|
11 |
+
import click
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import copy
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import dnnlib
|
18 |
+
from torch_utils import misc
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
def load_network_pkl(f, force_fp16=False):
|
23 |
+
data = _LegacyUnpickler(f).load()
|
24 |
+
|
25 |
+
# Legacy TensorFlow pickle => convert.
|
26 |
+
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
|
27 |
+
tf_G, tf_D, tf_Gs = data
|
28 |
+
G = convert_tf_generator(tf_G)
|
29 |
+
D = convert_tf_discriminator(tf_D)
|
30 |
+
G_ema = convert_tf_generator(tf_Gs)
|
31 |
+
data = dict(G=G, D=D, G_ema=G_ema)
|
32 |
+
|
33 |
+
# Add missing fields.
|
34 |
+
if 'training_set_kwargs' not in data:
|
35 |
+
data['training_set_kwargs'] = None
|
36 |
+
if 'augment_pipe' not in data:
|
37 |
+
data['augment_pipe'] = None
|
38 |
+
|
39 |
+
# Validate contents.
|
40 |
+
assert isinstance(data['G'], torch.nn.Module)
|
41 |
+
assert isinstance(data['D'], torch.nn.Module)
|
42 |
+
assert isinstance(data['G_ema'], torch.nn.Module)
|
43 |
+
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
|
44 |
+
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
|
45 |
+
|
46 |
+
# Force FP16.
|
47 |
+
if force_fp16:
|
48 |
+
for key in ['G', 'D', 'G_ema']:
|
49 |
+
old = data[key]
|
50 |
+
kwargs = copy.deepcopy(old.init_kwargs)
|
51 |
+
fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
|
52 |
+
fp16_kwargs.num_fp16_res = 4
|
53 |
+
fp16_kwargs.conv_clamp = 256
|
54 |
+
if kwargs != old.init_kwargs:
|
55 |
+
new = type(old)(**kwargs).eval().requires_grad_(False)
|
56 |
+
misc.copy_params_and_buffers(old, new, require_all=True)
|
57 |
+
data[key] = new
|
58 |
+
return data
|
59 |
+
|
60 |
+
#----------------------------------------------------------------------------
|
61 |
+
|
62 |
+
class _TFNetworkStub(dnnlib.EasyDict):
|
63 |
+
pass
|
64 |
+
|
65 |
+
class _LegacyUnpickler(pickle.Unpickler):
|
66 |
+
def find_class(self, module, name):
|
67 |
+
if module == 'dnnlib.tflib.network' and name == 'Network':
|
68 |
+
return _TFNetworkStub
|
69 |
+
return super().find_class(module, name)
|
70 |
+
|
71 |
+
#----------------------------------------------------------------------------
|
72 |
+
|
73 |
+
def _collect_tf_params(tf_net):
|
74 |
+
# pylint: disable=protected-access
|
75 |
+
tf_params = dict()
|
76 |
+
def recurse(prefix, tf_net):
|
77 |
+
for name, value in tf_net.variables:
|
78 |
+
tf_params[prefix + name] = value
|
79 |
+
for name, comp in tf_net.components.items():
|
80 |
+
recurse(prefix + name + '/', comp)
|
81 |
+
recurse('', tf_net)
|
82 |
+
return tf_params
|
83 |
+
|
84 |
+
#----------------------------------------------------------------------------
|
85 |
+
|
86 |
+
def _populate_module_params(module, *patterns):
|
87 |
+
for name, tensor in misc.named_params_and_buffers(module):
|
88 |
+
found = False
|
89 |
+
value = None
|
90 |
+
for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
|
91 |
+
match = re.fullmatch(pattern, name)
|
92 |
+
if match:
|
93 |
+
found = True
|
94 |
+
if value_fn is not None:
|
95 |
+
value = value_fn(*match.groups())
|
96 |
+
break
|
97 |
+
try:
|
98 |
+
assert found
|
99 |
+
if value is not None:
|
100 |
+
tensor.copy_(torch.from_numpy(np.array(value)))
|
101 |
+
except:
|
102 |
+
print(name, list(tensor.shape))
|
103 |
+
raise
|
104 |
+
|
105 |
+
#----------------------------------------------------------------------------
|
106 |
+
|
107 |
+
def convert_tf_generator(tf_G):
|
108 |
+
if tf_G.version < 4:
|
109 |
+
raise ValueError('TensorFlow pickle version too low')
|
110 |
+
|
111 |
+
# Collect kwargs.
|
112 |
+
tf_kwargs = tf_G.static_kwargs
|
113 |
+
known_kwargs = set()
|
114 |
+
def kwarg(tf_name, default=None, none=None):
|
115 |
+
known_kwargs.add(tf_name)
|
116 |
+
val = tf_kwargs.get(tf_name, default)
|
117 |
+
return val if val is not None else none
|
118 |
+
|
119 |
+
# Convert kwargs.
|
120 |
+
from training import networks_stylegan2
|
121 |
+
network_class = networks_stylegan2.Generator
|
122 |
+
kwargs = dnnlib.EasyDict(
|
123 |
+
z_dim = kwarg('latent_size', 512),
|
124 |
+
c_dim = kwarg('label_size', 0),
|
125 |
+
w_dim = kwarg('dlatent_size', 512),
|
126 |
+
img_resolution = kwarg('resolution', 1024),
|
127 |
+
img_channels = kwarg('num_channels', 3),
|
128 |
+
channel_base = kwarg('fmap_base', 16384) * 2,
|
129 |
+
channel_max = kwarg('fmap_max', 512),
|
130 |
+
num_fp16_res = kwarg('num_fp16_res', 0),
|
131 |
+
conv_clamp = kwarg('conv_clamp', None),
|
132 |
+
architecture = kwarg('architecture', 'skip'),
|
133 |
+
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
134 |
+
use_noise = kwarg('use_noise', True),
|
135 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
136 |
+
mapping_kwargs = dnnlib.EasyDict(
|
137 |
+
num_layers = kwarg('mapping_layers', 8),
|
138 |
+
embed_features = kwarg('label_fmaps', None),
|
139 |
+
layer_features = kwarg('mapping_fmaps', None),
|
140 |
+
activation = kwarg('mapping_nonlinearity', 'lrelu'),
|
141 |
+
lr_multiplier = kwarg('mapping_lrmul', 0.01),
|
142 |
+
w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
|
143 |
+
),
|
144 |
+
)
|
145 |
+
|
146 |
+
# Check for unknown kwargs.
|
147 |
+
kwarg('truncation_psi')
|
148 |
+
kwarg('truncation_cutoff')
|
149 |
+
kwarg('style_mixing_prob')
|
150 |
+
kwarg('structure')
|
151 |
+
kwarg('conditioning')
|
152 |
+
kwarg('fused_modconv')
|
153 |
+
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
154 |
+
if len(unknown_kwargs) > 0:
|
155 |
+
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
156 |
+
|
157 |
+
# Collect params.
|
158 |
+
tf_params = _collect_tf_params(tf_G)
|
159 |
+
for name, value in list(tf_params.items()):
|
160 |
+
match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
|
161 |
+
if match:
|
162 |
+
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
163 |
+
tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
|
164 |
+
kwargs.synthesis.kwargs.architecture = 'orig'
|
165 |
+
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
166 |
+
|
167 |
+
# Convert params.
|
168 |
+
G = network_class(**kwargs).eval().requires_grad_(False)
|
169 |
+
# pylint: disable=unnecessary-lambda
|
170 |
+
# pylint: disable=f-string-without-interpolation
|
171 |
+
_populate_module_params(G,
|
172 |
+
r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
|
173 |
+
r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
|
174 |
+
r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
|
175 |
+
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
|
176 |
+
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
|
177 |
+
r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
|
178 |
+
r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
179 |
+
r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
|
180 |
+
r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
|
181 |
+
r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
|
182 |
+
r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
|
183 |
+
r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
|
184 |
+
r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
185 |
+
r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
|
186 |
+
r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
|
187 |
+
r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
|
188 |
+
r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
|
189 |
+
r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
|
190 |
+
r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
|
191 |
+
r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
|
192 |
+
r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
|
193 |
+
r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
|
194 |
+
r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
|
195 |
+
r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
|
196 |
+
r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
|
197 |
+
r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
|
198 |
+
r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
|
199 |
+
r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
|
200 |
+
r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
201 |
+
r'.*\.resample_filter', None,
|
202 |
+
r'.*\.act_filter', None,
|
203 |
+
)
|
204 |
+
return G
|
205 |
+
|
206 |
+
#----------------------------------------------------------------------------
|
207 |
+
|
208 |
+
def convert_tf_discriminator(tf_D):
|
209 |
+
if tf_D.version < 4:
|
210 |
+
raise ValueError('TensorFlow pickle version too low')
|
211 |
+
|
212 |
+
# Collect kwargs.
|
213 |
+
tf_kwargs = tf_D.static_kwargs
|
214 |
+
known_kwargs = set()
|
215 |
+
def kwarg(tf_name, default=None):
|
216 |
+
known_kwargs.add(tf_name)
|
217 |
+
return tf_kwargs.get(tf_name, default)
|
218 |
+
|
219 |
+
# Convert kwargs.
|
220 |
+
kwargs = dnnlib.EasyDict(
|
221 |
+
c_dim = kwarg('label_size', 0),
|
222 |
+
img_resolution = kwarg('resolution', 1024),
|
223 |
+
img_channels = kwarg('num_channels', 3),
|
224 |
+
architecture = kwarg('architecture', 'resnet'),
|
225 |
+
channel_base = kwarg('fmap_base', 16384) * 2,
|
226 |
+
channel_max = kwarg('fmap_max', 512),
|
227 |
+
num_fp16_res = kwarg('num_fp16_res', 0),
|
228 |
+
conv_clamp = kwarg('conv_clamp', None),
|
229 |
+
cmap_dim = kwarg('mapping_fmaps', None),
|
230 |
+
block_kwargs = dnnlib.EasyDict(
|
231 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
232 |
+
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
233 |
+
freeze_layers = kwarg('freeze_layers', 0),
|
234 |
+
),
|
235 |
+
mapping_kwargs = dnnlib.EasyDict(
|
236 |
+
num_layers = kwarg('mapping_layers', 0),
|
237 |
+
embed_features = kwarg('mapping_fmaps', None),
|
238 |
+
layer_features = kwarg('mapping_fmaps', None),
|
239 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
240 |
+
lr_multiplier = kwarg('mapping_lrmul', 0.1),
|
241 |
+
),
|
242 |
+
epilogue_kwargs = dnnlib.EasyDict(
|
243 |
+
mbstd_group_size = kwarg('mbstd_group_size', None),
|
244 |
+
mbstd_num_channels = kwarg('mbstd_num_features', 1),
|
245 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
246 |
+
),
|
247 |
+
)
|
248 |
+
|
249 |
+
# Check for unknown kwargs.
|
250 |
+
kwarg('structure')
|
251 |
+
kwarg('conditioning')
|
252 |
+
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
253 |
+
if len(unknown_kwargs) > 0:
|
254 |
+
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
255 |
+
|
256 |
+
# Collect params.
|
257 |
+
tf_params = _collect_tf_params(tf_D)
|
258 |
+
for name, value in list(tf_params.items()):
|
259 |
+
match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
|
260 |
+
if match:
|
261 |
+
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
262 |
+
tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
|
263 |
+
kwargs.architecture = 'orig'
|
264 |
+
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
265 |
+
|
266 |
+
# Convert params.
|
267 |
+
from training import networks_stylegan2
|
268 |
+
D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
|
269 |
+
# pylint: disable=unnecessary-lambda
|
270 |
+
# pylint: disable=f-string-without-interpolation
|
271 |
+
_populate_module_params(D,
|
272 |
+
r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
|
273 |
+
r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
|
274 |
+
r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
|
275 |
+
r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
|
276 |
+
r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
|
277 |
+
r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
|
278 |
+
r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
|
279 |
+
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
|
280 |
+
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
|
281 |
+
r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
282 |
+
r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
|
283 |
+
r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
|
284 |
+
r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
|
285 |
+
r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
|
286 |
+
r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
|
287 |
+
r'.*\.resample_filter', None,
|
288 |
+
)
|
289 |
+
return D
|
290 |
+
|
291 |
+
#----------------------------------------------------------------------------
|
292 |
+
|
293 |
+
@click.command()
|
294 |
+
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
|
295 |
+
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
|
296 |
+
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
|
297 |
+
def convert_network_pickle(source, dest, force_fp16):
|
298 |
+
"""Convert legacy network pickle into the native PyTorch format.
|
299 |
+
|
300 |
+
The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
|
301 |
+
It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
|
302 |
+
|
303 |
+
Example:
|
304 |
+
|
305 |
+
\b
|
306 |
+
python legacy.py \\
|
307 |
+
--source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
|
308 |
+
--dest=stylegan2-cat-config-f.pkl
|
309 |
+
"""
|
310 |
+
print(f'Loading "{source}"...')
|
311 |
+
with dnnlib.util.open_url(source) as f:
|
312 |
+
data = load_network_pkl(f, force_fp16=force_fp16)
|
313 |
+
print(f'Saving "{dest}"...')
|
314 |
+
with open(dest, 'wb') as f:
|
315 |
+
pickle.dump(data, f)
|
316 |
+
print('Done.')
|
317 |
+
|
318 |
+
#----------------------------------------------------------------------------
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
convert_network_pickle() # pylint: disable=no-value-for-parameter
|
322 |
+
|
323 |
+
#----------------------------------------------------------------------------
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
scipy==1.11.0
|
3 |
+
Ninja==1.10.2
|
4 |
+
gradio>=3.35.2
|
5 |
+
imageio-ffmpeg>=0.4.3
|
6 |
+
huggingface_hub
|
7 |
+
hf_transfer
|
8 |
+
pyopengl
|
9 |
+
imgui
|
10 |
+
glfw==2.6.1
|
11 |
+
pillow>=9.4.0
|
12 |
+
torchvision>=0.15.2
|
13 |
+
imageio>=2.9.0
|
scripts/download_model.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import requests
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
def download_file(url: str, filename: str, download_dir: str):
|
8 |
+
"""Download a file if it does not already exist."""
|
9 |
+
|
10 |
+
try:
|
11 |
+
filepath = os.path.join(download_dir, filename)
|
12 |
+
content_length = int(requests.head(url).headers.get("content-length", 0))
|
13 |
+
|
14 |
+
# If file already exists and size matches, skip download
|
15 |
+
if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length:
|
16 |
+
print(f"{filepath} already exists. Skipping download.")
|
17 |
+
return
|
18 |
+
if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length:
|
19 |
+
print(f"{filepath} already exists but size does not match. Redownloading.")
|
20 |
+
else:
|
21 |
+
print(f"Downloading {filename} from {url}")
|
22 |
+
|
23 |
+
# Start download, stream=True allows for progress tracking
|
24 |
+
response = requests.get(url, stream=True)
|
25 |
+
|
26 |
+
# Check if request was successful
|
27 |
+
response.raise_for_status()
|
28 |
+
|
29 |
+
# Create progress bar
|
30 |
+
total_size = int(response.headers.get('content-length', 0))
|
31 |
+
progress_bar = tqdm(
|
32 |
+
total=total_size,
|
33 |
+
unit='iB',
|
34 |
+
unit_scale=True,
|
35 |
+
ncols=70,
|
36 |
+
file=sys.stdout
|
37 |
+
)
|
38 |
+
|
39 |
+
# Write response content to file
|
40 |
+
with open(filepath, 'wb') as f:
|
41 |
+
for data in response.iter_content(chunk_size=1024):
|
42 |
+
f.write(data)
|
43 |
+
progress_bar.update(len(data)) # Update progress bar
|
44 |
+
|
45 |
+
# Close progress bar
|
46 |
+
progress_bar.close()
|
47 |
+
|
48 |
+
# Error handling for incomplete downloads
|
49 |
+
if total_size != 0 and progress_bar.n != total_size:
|
50 |
+
print("ERROR, something went wrong while downloading")
|
51 |
+
raise Exception()
|
52 |
+
|
53 |
+
|
54 |
+
except Exception as e:
|
55 |
+
print(f"An error occurred: {e}")
|
56 |
+
|
57 |
+
def main():
|
58 |
+
"""Main function to download files from URLs in a config file."""
|
59 |
+
|
60 |
+
# Get JSON config file path
|
61 |
+
script_dir = os.path.dirname(os.path.realpath(__file__))
|
62 |
+
config_file_path = os.path.join(script_dir, "download_models.json")
|
63 |
+
|
64 |
+
# Set download directory
|
65 |
+
download_dir = "checkpoints"
|
66 |
+
os.makedirs(download_dir, exist_ok=True)
|
67 |
+
|
68 |
+
# Load URL and filenames from JSON
|
69 |
+
with open(config_file_path, "r") as f:
|
70 |
+
config = json.load(f)
|
71 |
+
|
72 |
+
# Download each file specified in config
|
73 |
+
for url, filename in config.items():
|
74 |
+
download_file(url, filename, download_dir)
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
main()
|
scripts/download_models.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl": "stylegan2_lions_512_pytorch.pkl",
|
3 |
+
"https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl": "stylegan2_dogs_1024_pytorch.pkl",
|
4 |
+
"https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl": "stylegan2_horses_256_pytorch.pkl",
|
5 |
+
"https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl": "stylegan2_elephants_512_pytorch.pkl",
|
6 |
+
"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl": "stylegan2-ffhq-512x512.pkl",
|
7 |
+
"https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl": "stylegan2-afhqcat-512x512.pkl",
|
8 |
+
"http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl": "stylegan2-car-config-f.pkl",
|
9 |
+
"http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl": "stylegan2-cat-config-f.pkl"
|
10 |
+
}
|
scripts/gui.bat
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
python visualizer_drag.py ^
|
3 |
+
checkpoints/stylegan2_lions_512_pytorch.pkl ^
|
4 |
+
checkpoints/stylegan2-ffhq-512x512.pkl ^
|
5 |
+
checkpoints/stylegan2-afhqcat-512x512.pkl ^
|
6 |
+
checkpoints/stylegan2-car-config-f.pkl ^
|
7 |
+
checkpoints/stylegan2_dogs_1024_pytorch.pkl ^
|
8 |
+
checkpoints/stylegan2_horses_256_pytorch.pkl ^
|
9 |
+
checkpoints/stylegan2-cat-config-f.pkl ^
|
10 |
+
checkpoints/stylegan2_elephants_512_pytorch.pkl ^
|
11 |
+
checkpoints/stylegan_human_v2_512.pkl ^
|
12 |
+
checkpoints/stylegan2-lhq-256x256.pkl
|
scripts/gui.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python visualizer_drag.py \
|
2 |
+
checkpoints/stylegan2_lions_512_pytorch.pkl \
|
3 |
+
checkpoints/stylegan2-ffhq-512x512.pkl \
|
4 |
+
checkpoints/stylegan2-afhqcat-512x512.pkl \
|
5 |
+
checkpoints/stylegan2-car-config-f.pkl \
|
6 |
+
checkpoints/stylegan2_dogs_1024_pytorch.pkl \
|
7 |
+
checkpoints/stylegan2_horses_256_pytorch.pkl \
|
8 |
+
checkpoints/stylegan2-cat-config-f.pkl \
|
9 |
+
checkpoints/stylegan2_elephants_512_pytorch.pkl \
|
10 |
+
checkpoints/stylegan_human_v2_512.pkl \
|
11 |
+
checkpoints/stylegan2-lhq-256x256.pkl
|
stylegan_human/.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
__pycache__
|
3 |
+
*.pt
|
4 |
+
*.pth
|
5 |
+
*.pdparams
|
6 |
+
*.pdiparams
|
7 |
+
*.pdmodel
|
8 |
+
*.pkl
|
9 |
+
*.info
|
10 |
+
*.yaml
|
stylegan_human/PP_HumanSeg/deploy/infer.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import codecs
|
19 |
+
import os
|
20 |
+
import time
|
21 |
+
|
22 |
+
import yaml
|
23 |
+
import numpy as np
|
24 |
+
import cv2
|
25 |
+
import paddle
|
26 |
+
import paddleseg.transforms as T
|
27 |
+
from paddle.inference import create_predictor, PrecisionType
|
28 |
+
from paddle.inference import Config as PredictConfig
|
29 |
+
from paddleseg.core.infer import reverse_transform
|
30 |
+
from paddleseg.cvlibs import manager
|
31 |
+
from paddleseg.utils import TimeAverager
|
32 |
+
|
33 |
+
from ..scripts.optic_flow_process import optic_flow_process
|
34 |
+
|
35 |
+
|
36 |
+
class DeployConfig:
|
37 |
+
def __init__(self, path):
|
38 |
+
with codecs.open(path, 'r', 'utf-8') as file:
|
39 |
+
self.dic = yaml.load(file, Loader=yaml.FullLoader)
|
40 |
+
|
41 |
+
self._transforms = self._load_transforms(self.dic['Deploy'][
|
42 |
+
'transforms'])
|
43 |
+
self._dir = os.path.dirname(path)
|
44 |
+
|
45 |
+
@property
|
46 |
+
def transforms(self):
|
47 |
+
return self._transforms
|
48 |
+
|
49 |
+
@property
|
50 |
+
def model(self):
|
51 |
+
return os.path.join(self._dir, self.dic['Deploy']['model'])
|
52 |
+
|
53 |
+
@property
|
54 |
+
def params(self):
|
55 |
+
return os.path.join(self._dir, self.dic['Deploy']['params'])
|
56 |
+
|
57 |
+
def _load_transforms(self, t_list):
|
58 |
+
com = manager.TRANSFORMS
|
59 |
+
transforms = []
|
60 |
+
for t in t_list:
|
61 |
+
ctype = t.pop('type')
|
62 |
+
transforms.append(com[ctype](**t))
|
63 |
+
|
64 |
+
return transforms
|
65 |
+
|
66 |
+
|
67 |
+
class Predictor:
|
68 |
+
def __init__(self, args):
|
69 |
+
self.cfg = DeployConfig(args.cfg)
|
70 |
+
self.args = args
|
71 |
+
self.compose = T.Compose(self.cfg.transforms)
|
72 |
+
resize_h, resize_w = args.input_shape
|
73 |
+
|
74 |
+
self.disflow = cv2.DISOpticalFlow_create(
|
75 |
+
cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
|
76 |
+
self.prev_gray = np.zeros((resize_h, resize_w), np.uint8)
|
77 |
+
self.prev_cfd = np.zeros((resize_h, resize_w), np.float32)
|
78 |
+
self.is_init = True
|
79 |
+
|
80 |
+
pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
|
81 |
+
pred_cfg.disable_glog_info()
|
82 |
+
if self.args.use_gpu:
|
83 |
+
pred_cfg.enable_use_gpu(100, 0)
|
84 |
+
|
85 |
+
self.predictor = create_predictor(pred_cfg)
|
86 |
+
if self.args.test_speed:
|
87 |
+
self.cost_averager = TimeAverager()
|
88 |
+
|
89 |
+
def preprocess(self, img):
|
90 |
+
ori_shapes = []
|
91 |
+
processed_imgs = []
|
92 |
+
processed_img = self.compose(img)[0]
|
93 |
+
processed_imgs.append(processed_img)
|
94 |
+
ori_shapes.append(img.shape)
|
95 |
+
return processed_imgs, ori_shapes
|
96 |
+
|
97 |
+
def run(self, img, bg):
|
98 |
+
input_names = self.predictor.get_input_names()
|
99 |
+
input_handle = self.predictor.get_input_handle(input_names[0])
|
100 |
+
processed_imgs, ori_shapes = self.preprocess(img)
|
101 |
+
data = np.array(processed_imgs)
|
102 |
+
input_handle.reshape(data.shape)
|
103 |
+
input_handle.copy_from_cpu(data)
|
104 |
+
if self.args.test_speed:
|
105 |
+
start = time.time()
|
106 |
+
|
107 |
+
self.predictor.run()
|
108 |
+
|
109 |
+
if self.args.test_speed:
|
110 |
+
self.cost_averager.record(time.time() - start)
|
111 |
+
output_names = self.predictor.get_output_names()
|
112 |
+
output_handle = self.predictor.get_output_handle(output_names[0])
|
113 |
+
output = output_handle.copy_to_cpu()
|
114 |
+
return self.postprocess(output, img, ori_shapes[0], bg)
|
115 |
+
|
116 |
+
|
117 |
+
def postprocess(self, pred, img, ori_shape, bg):
|
118 |
+
if not os.path.exists(self.args.save_dir):
|
119 |
+
os.makedirs(self.args.save_dir)
|
120 |
+
resize_w = pred.shape[-1]
|
121 |
+
resize_h = pred.shape[-2]
|
122 |
+
if self.args.soft_predict:
|
123 |
+
if self.args.use_optic_flow:
|
124 |
+
score_map = pred[:, 1, :, :].squeeze(0)
|
125 |
+
score_map = 255 * score_map
|
126 |
+
cur_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
127 |
+
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
|
128 |
+
optflow_map = optic_flow_process(cur_gray, score_map, self.prev_gray, self.prev_cfd, \
|
129 |
+
self.disflow, self.is_init)
|
130 |
+
self.prev_gray = cur_gray.copy()
|
131 |
+
self.prev_cfd = optflow_map.copy()
|
132 |
+
self.is_init = False
|
133 |
+
|
134 |
+
score_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
|
135 |
+
score_map = np.transpose(score_map, [2, 0, 1])[np.newaxis, ...]
|
136 |
+
score_map = reverse_transform(
|
137 |
+
paddle.to_tensor(score_map),
|
138 |
+
ori_shape,
|
139 |
+
self.cfg.transforms,
|
140 |
+
mode='bilinear')
|
141 |
+
alpha = np.transpose(score_map.numpy().squeeze(0),
|
142 |
+
[1, 2, 0]) / 255
|
143 |
+
else:
|
144 |
+
score_map = pred[:, 1, :, :]
|
145 |
+
score_map = score_map[np.newaxis, ...]
|
146 |
+
score_map = reverse_transform(
|
147 |
+
paddle.to_tensor(score_map),
|
148 |
+
ori_shape,
|
149 |
+
self.cfg.transforms,
|
150 |
+
mode='bilinear')
|
151 |
+
alpha = np.transpose(score_map.numpy().squeeze(0), [1, 2, 0])
|
152 |
+
|
153 |
+
else:
|
154 |
+
if pred.ndim == 3:
|
155 |
+
pred = pred[:, np.newaxis, ...]
|
156 |
+
result = reverse_transform(
|
157 |
+
paddle.to_tensor(
|
158 |
+
pred, dtype='float32'),
|
159 |
+
ori_shape,
|
160 |
+
self.cfg.transforms,
|
161 |
+
mode='bilinear')
|
162 |
+
|
163 |
+
result = np.array(result)
|
164 |
+
if self.args.add_argmax:
|
165 |
+
result = np.argmax(result, axis=1)
|
166 |
+
else:
|
167 |
+
result = result.squeeze(1)
|
168 |
+
alpha = np.transpose(result, [1, 2, 0])
|
169 |
+
|
170 |
+
# background replace
|
171 |
+
h, w, _ = img.shape
|
172 |
+
if bg is None:
|
173 |
+
bg = np.ones_like(img)*255
|
174 |
+
else:
|
175 |
+
bg = cv2.resize(bg, (w, h))
|
176 |
+
if bg.ndim == 2:
|
177 |
+
bg = bg[..., np.newaxis]
|
178 |
+
|
179 |
+
comb = (alpha * img + (1 - alpha) * bg).astype(np.uint8)
|
180 |
+
return comb, alpha, bg, img
|
stylegan_human/PP_HumanSeg/export_model/download_export_model.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf8
|
2 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import sys
|
17 |
+
import os
|
18 |
+
|
19 |
+
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
|
20 |
+
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
|
21 |
+
sys.path.append(TEST_PATH)
|
22 |
+
|
23 |
+
from paddleseg.utils.download import download_file_and_uncompress
|
24 |
+
|
25 |
+
model_urls = {
|
26 |
+
"pphumanseg_lite_portrait_398x224_with_softmax":
|
27 |
+
"https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz",
|
28 |
+
"deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax":
|
29 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/deeplabv3p_resnet50_os8_humanseg_512x512_100k_with_softmax.zip",
|
30 |
+
"fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax":
|
31 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/fcn_hrnetw18_small_v1_humanseg_192x192_with_softmax.zip",
|
32 |
+
"pphumanseg_lite_generic_humanseg_192x192_with_softmax":
|
33 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/export/pphumanseg_lite_generic_192x192_with_softmax.zip",
|
34 |
+
}
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
for model_name, url in model_urls.items():
|
38 |
+
download_file_and_uncompress(
|
39 |
+
url=url,
|
40 |
+
savepath=LOCAL_PATH,
|
41 |
+
extrapath=LOCAL_PATH,
|
42 |
+
extraname=model_name)
|
43 |
+
|
44 |
+
print("Export model download success!")
|
stylegan_human/PP_HumanSeg/pretrained_model/download_pretrained_model.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf8
|
2 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import sys
|
17 |
+
import os
|
18 |
+
|
19 |
+
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
|
20 |
+
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
|
21 |
+
sys.path.append(TEST_PATH)
|
22 |
+
|
23 |
+
from paddleseg.utils.download import download_file_and_uncompress
|
24 |
+
|
25 |
+
model_urls = {
|
26 |
+
"pphumanseg_lite_portrait_398x224":
|
27 |
+
"https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224.tar.gz",
|
28 |
+
"deeplabv3p_resnet50_os8_humanseg_512x512_100k":
|
29 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/deeplabv3p_resnet50_os8_humanseg_512x512_100k.zip",
|
30 |
+
"fcn_hrnetw18_small_v1_humanseg_192x192":
|
31 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/fcn_hrnetw18_small_v1_humanseg_192x192.zip",
|
32 |
+
"pphumanseg_lite_generic_human_192x192":
|
33 |
+
"https://paddleseg.bj.bcebos.com/dygraph/humanseg/train/pphumanseg_lite_generic_192x192.zip",
|
34 |
+
}
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
for model_name, url in model_urls.items():
|
38 |
+
download_file_and_uncompress(
|
39 |
+
url=url,
|
40 |
+
savepath=LOCAL_PATH,
|
41 |
+
extrapath=LOCAL_PATH,
|
42 |
+
extraname=model_name)
|
43 |
+
|
44 |
+
print("Pretrained model download success!")
|