Add spar3d demo files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +167 -0
- .pre-commit-config.yaml +24 -0
- LICENSE.md +51 -0
- README.md +3 -3
- __init__.py +358 -0
- demo_files/comp.gif +3 -0
- demo_files/examples/bird.png +3 -0
- demo_files/examples/castle.png +3 -0
- demo_files/examples/chest.png +3 -0
- demo_files/examples/doll.png +3 -0
- demo_files/examples/excavator.png +3 -0
- demo_files/examples/fish.png +3 -0
- demo_files/examples/horse-statue.png +3 -0
- demo_files/examples/penguin.png +3 -0
- demo_files/examples/pot.png +3 -0
- demo_files/examples/raccoon_wizard.png +3 -0
- demo_files/examples/stylized-rocks.png +3 -0
- demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
- demo_files/hdri/metro_noord_1k.hdr +0 -0
- demo_files/hdri/neon_photostudio_1k.hdr +0 -0
- demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
- demo_files/hdri/rainforest_trail_1k.hdr +0 -0
- demo_files/hdri/studio_small_08_1k.hdr +0 -0
- demo_files/hdri/urban_alley_01_1k.hdr +0 -0
- demo_files/turntable.gif +3 -0
- demo_files/workflows/spar3d_example.json +263 -0
- gradio_app.py +792 -0
- load/tets/160_tets.npz +3 -0
- requirements.txt +17 -0
- ruff.toml +3 -0
- run.py +180 -0
- spar3d/models/camera.py +32 -0
- spar3d/models/diffusion/gaussian_diffusion.py +524 -0
- spar3d/models/diffusion/sampler.py +134 -0
- spar3d/models/global_estimator/reni_estimator.py +112 -0
- spar3d/models/illumination/reni/components/film_siren.py +148 -0
- spar3d/models/illumination/reni/components/siren.py +118 -0
- spar3d/models/illumination/reni/components/transformer_decoder.py +189 -0
- spar3d/models/illumination/reni/components/vn_layers.py +548 -0
- spar3d/models/illumination/reni/env_map.py +93 -0
- spar3d/models/illumination/reni/field.py +736 -0
- spar3d/models/image_estimator/clip_based_estimator.py +184 -0
- spar3d/models/isosurface.py +229 -0
- spar3d/models/mesh.py +317 -0
- spar3d/models/network.py +223 -0
- spar3d/models/tokenizers/dinov2.py +1196 -0
- spar3d/models/tokenizers/image.py +99 -0
- spar3d/models/tokenizers/point.py +51 -0
- spar3d/models/tokenizers/triplane.py +49 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv*/
|
127 |
+
env/
|
128 |
+
venv*/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
#.idea/
|
162 |
+
.vs/
|
163 |
+
.idea/
|
164 |
+
.vscode/
|
165 |
+
|
166 |
+
stabilityai/
|
167 |
+
output/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_language_version:
|
2 |
+
python: python3
|
3 |
+
|
4 |
+
repos:
|
5 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
6 |
+
rev: v4.4.0
|
7 |
+
hooks:
|
8 |
+
- id: trailing-whitespace
|
9 |
+
- id: check-ast
|
10 |
+
- id: check-merge-conflict
|
11 |
+
- id: check-yaml
|
12 |
+
- id: end-of-file-fixer
|
13 |
+
- id: trailing-whitespace
|
14 |
+
args: [--markdown-linebreak-ext=md]
|
15 |
+
|
16 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
17 |
+
# Ruff version.
|
18 |
+
rev: v0.3.5
|
19 |
+
hooks:
|
20 |
+
# Run the linter.
|
21 |
+
- id: ruff
|
22 |
+
args: [ --fix ]
|
23 |
+
# Run the formatter.
|
24 |
+
- id: ruff-format
|
LICENSE.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
STABILITY AI COMMUNITY LICENSE AGREEMENT
|
2 |
+
Last Updated: July 5, 2024
|
3 |
+
|
4 |
+
|
5 |
+
I. INTRODUCTION
|
6 |
+
|
7 |
+
This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
|
8 |
+
|
9 |
+
|
10 |
+
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
|
11 |
+
|
12 |
+
|
13 |
+
By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
|
14 |
+
|
15 |
+
II. RESEARCH & NON-COMMERCIAL USE LICENSE
|
16 |
+
|
17 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
|
18 |
+
|
19 |
+
III. COMMERCIAL USE LICENSE
|
20 |
+
|
21 |
+
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
|
22 |
+
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
|
23 |
+
|
24 |
+
IV. GENERAL TERMS
|
25 |
+
|
26 |
+
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
|
27 |
+
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
|
28 |
+
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
|
29 |
+
c. Intellectual Property.
|
30 |
+
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
|
31 |
+
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
32 |
+
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
|
33 |
+
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
|
34 |
+
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
|
35 |
+
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
36 |
+
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
37 |
+
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
|
38 |
+
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
|
39 |
+
|
40 |
+
V. DEFINITIONS
|
41 |
+
|
42 |
+
"Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
|
43 |
+
"Agreement" means this Stability AI Community License Agreement.
|
44 |
+
"AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
|
45 |
+
"Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
|
46 |
+
"Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
|
47 |
+
"Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
|
48 |
+
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
|
49 |
+
"Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
|
50 |
+
"Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
|
51 |
+
"Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
README.md
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
---
|
2 |
-
title: Stable Point
|
3 |
emoji: ⚡
|
4 |
colorFrom: yellow
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
1 |
---
|
2 |
+
title: Stable Point-Aware 3D
|
3 |
emoji: ⚡
|
4 |
colorFrom: yellow
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.43.0
|
8 |
+
app_file: gradio_app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
__init__.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import comfy.model_management
|
8 |
+
import folder_paths
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import trimesh
|
12 |
+
from PIL import Image
|
13 |
+
from trimesh.exchange import gltf
|
14 |
+
|
15 |
+
sys.path.append(os.path.dirname(__file__))
|
16 |
+
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
|
17 |
+
from spar3d.system import SPAR3D
|
18 |
+
from spar3d.utils import foreground_crop
|
19 |
+
|
20 |
+
SPAR3D_CATEGORY = "SPAR3D"
|
21 |
+
SPAR3D_MODEL_NAME = "stabilityai/spar3d"
|
22 |
+
|
23 |
+
|
24 |
+
class SPAR3DLoader:
|
25 |
+
CATEGORY = SPAR3D_CATEGORY
|
26 |
+
FUNCTION = "load"
|
27 |
+
RETURN_NAMES = ("spar3d_model",)
|
28 |
+
RETURN_TYPES = ("SPAR3D_MODEL",)
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def INPUT_TYPES(cls):
|
32 |
+
return {"required": {}}
|
33 |
+
|
34 |
+
def load(self):
|
35 |
+
device = comfy.model_management.get_torch_device()
|
36 |
+
model = SPAR3D.from_pretrained(
|
37 |
+
SPAR3D_MODEL_NAME,
|
38 |
+
config_name="config.yaml",
|
39 |
+
weight_name="model.safetensors",
|
40 |
+
)
|
41 |
+
model.to(device)
|
42 |
+
model.eval()
|
43 |
+
|
44 |
+
return (model,)
|
45 |
+
|
46 |
+
|
47 |
+
class SPAR3DPreview:
|
48 |
+
CATEGORY = SPAR3D_CATEGORY
|
49 |
+
FUNCTION = "preview"
|
50 |
+
OUTPUT_NODE = True
|
51 |
+
RETURN_TYPES = ()
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def INPUT_TYPES(s):
|
55 |
+
return {"required": {"mesh": ("MESH",)}}
|
56 |
+
|
57 |
+
def preview(self, mesh):
|
58 |
+
glbs = []
|
59 |
+
for m in mesh:
|
60 |
+
scene = trimesh.Scene(m)
|
61 |
+
glb_data = gltf.export_glb(scene, include_normals=True)
|
62 |
+
glb_base64 = base64.b64encode(glb_data).decode("utf-8")
|
63 |
+
glbs.append(glb_base64)
|
64 |
+
return {"ui": {"glbs": glbs}}
|
65 |
+
|
66 |
+
|
67 |
+
class SPAR3DSampler:
|
68 |
+
CATEGORY = SPAR3D_CATEGORY
|
69 |
+
FUNCTION = "predict"
|
70 |
+
RETURN_NAMES = ("mesh", "pointcloud")
|
71 |
+
RETURN_TYPES = ("MESH", "POINTCLOUD")
|
72 |
+
|
73 |
+
@classmethod
|
74 |
+
def INPUT_TYPES(s):
|
75 |
+
remesh_choices = ["none"]
|
76 |
+
if TRIANGLE_REMESH_AVAILABLE:
|
77 |
+
remesh_choices.append("triangle")
|
78 |
+
if QUAD_REMESH_AVAILABLE:
|
79 |
+
remesh_choices.append("quad")
|
80 |
+
|
81 |
+
opt_dict = {
|
82 |
+
"mask": ("MASK",),
|
83 |
+
"pointcloud": ("POINTCLOUD",),
|
84 |
+
"target_type": (["none", "vertex", "face"],),
|
85 |
+
"target_count": (
|
86 |
+
"INT",
|
87 |
+
{"default": 1000, "min": 3, "max": 20000, "step": 1},
|
88 |
+
),
|
89 |
+
"guidance_scale": (
|
90 |
+
"FLOAT",
|
91 |
+
{"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
|
92 |
+
),
|
93 |
+
"seed": (
|
94 |
+
"INT",
|
95 |
+
{"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
|
96 |
+
),
|
97 |
+
}
|
98 |
+
if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
|
99 |
+
opt_dict["remesh"] = (remesh_choices,)
|
100 |
+
|
101 |
+
return {
|
102 |
+
"required": {
|
103 |
+
"model": ("SPAR3D_MODEL",),
|
104 |
+
"image": ("IMAGE",),
|
105 |
+
"foreground_ratio": (
|
106 |
+
"FLOAT",
|
107 |
+
{"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
|
108 |
+
),
|
109 |
+
"texture_resolution": (
|
110 |
+
"INT",
|
111 |
+
{"default": 1024, "min": 512, "max": 2048, "step": 256},
|
112 |
+
),
|
113 |
+
},
|
114 |
+
"optional": opt_dict,
|
115 |
+
}
|
116 |
+
|
117 |
+
def predict(
|
118 |
+
s,
|
119 |
+
model,
|
120 |
+
image,
|
121 |
+
mask,
|
122 |
+
foreground_ratio,
|
123 |
+
texture_resolution,
|
124 |
+
pointcloud=None,
|
125 |
+
remesh="none",
|
126 |
+
target_type="none",
|
127 |
+
target_count=1000,
|
128 |
+
guidance_scale=3.0,
|
129 |
+
seed=42,
|
130 |
+
):
|
131 |
+
if image.shape[0] != 1:
|
132 |
+
raise ValueError("Only one image can be processed at a time")
|
133 |
+
|
134 |
+
vertex_count = (
|
135 |
+
-1
|
136 |
+
if target_type == "none"
|
137 |
+
else (target_count // 2 if target_type == "face" else target_count)
|
138 |
+
)
|
139 |
+
|
140 |
+
pil_image = Image.fromarray(
|
141 |
+
torch.clamp(torch.round(255.0 * image[0]), 0, 255)
|
142 |
+
.type(torch.uint8)
|
143 |
+
.cpu()
|
144 |
+
.numpy()
|
145 |
+
)
|
146 |
+
|
147 |
+
if mask is not None:
|
148 |
+
print("Using Mask")
|
149 |
+
mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
|
150 |
+
np.uint8
|
151 |
+
)
|
152 |
+
mask_pil = Image.fromarray(mask_np, mode="L")
|
153 |
+
pil_image.putalpha(mask_pil)
|
154 |
+
else:
|
155 |
+
if image.shape[3] != 4:
|
156 |
+
print("No mask or alpha channel detected, Converting to RGBA")
|
157 |
+
pil_image = pil_image.convert("RGBA")
|
158 |
+
|
159 |
+
pil_image = foreground_crop(pil_image, foreground_ratio)
|
160 |
+
|
161 |
+
model.cfg.guidance_scale = guidance_scale
|
162 |
+
random.seed(seed)
|
163 |
+
torch.manual_seed(seed)
|
164 |
+
np.random.seed(seed)
|
165 |
+
|
166 |
+
print(remesh)
|
167 |
+
with torch.no_grad():
|
168 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
169 |
+
if not TRIANGLE_REMESH_AVAILABLE and remesh == "triangle":
|
170 |
+
raise ImportError(
|
171 |
+
"Triangle remeshing requires gpytoolbox to be installed"
|
172 |
+
)
|
173 |
+
if not QUAD_REMESH_AVAILABLE and remesh == "quad":
|
174 |
+
raise ImportError("Quad remeshing requires pynim to be installed")
|
175 |
+
mesh, glob_dict = model.run_image(
|
176 |
+
pil_image,
|
177 |
+
bake_resolution=texture_resolution,
|
178 |
+
pointcloud=pointcloud,
|
179 |
+
remesh=remesh,
|
180 |
+
vertex_count=vertex_count,
|
181 |
+
)
|
182 |
+
|
183 |
+
if mesh.vertices.shape[0] == 0:
|
184 |
+
raise ValueError("No subject detected in the image")
|
185 |
+
|
186 |
+
return (
|
187 |
+
[mesh],
|
188 |
+
glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
class SPAR3DSave:
|
193 |
+
CATEGORY = SPAR3D_CATEGORY
|
194 |
+
FUNCTION = "save"
|
195 |
+
OUTPUT_NODE = True
|
196 |
+
RETURN_TYPES = ()
|
197 |
+
|
198 |
+
@classmethod
|
199 |
+
def INPUT_TYPES(s):
|
200 |
+
return {
|
201 |
+
"required": {
|
202 |
+
"mesh": ("MESH",),
|
203 |
+
"filename_prefix": ("STRING", {"default": "SPAR3D"}),
|
204 |
+
}
|
205 |
+
}
|
206 |
+
|
207 |
+
def __init__(self):
|
208 |
+
self.type = "output"
|
209 |
+
|
210 |
+
def save(self, mesh, filename_prefix):
|
211 |
+
output_dir = folder_paths.get_output_directory()
|
212 |
+
glbs = []
|
213 |
+
for idx, m in enumerate(mesh):
|
214 |
+
scene = trimesh.Scene(m)
|
215 |
+
glb_data = gltf.export_glb(scene, include_normals=True)
|
216 |
+
logging.info(f"Generated GLB model with {len(glb_data)} bytes")
|
217 |
+
|
218 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = (
|
219 |
+
folder_paths.get_save_image_path(filename_prefix, output_dir)
|
220 |
+
)
|
221 |
+
filename = filename.replace("%batch_num%", str(idx))
|
222 |
+
out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
|
223 |
+
with open(out_path, "wb") as f:
|
224 |
+
f.write(glb_data)
|
225 |
+
glbs.append(base64.b64encode(glb_data).decode("utf-8"))
|
226 |
+
return {"ui": {"glbs": glbs}}
|
227 |
+
|
228 |
+
|
229 |
+
class SPAR3DPointCloudLoader:
|
230 |
+
CATEGORY = SPAR3D_CATEGORY
|
231 |
+
FUNCTION = "load_pointcloud"
|
232 |
+
RETURN_TYPES = ("POINTCLOUD",)
|
233 |
+
RETURN_NAMES = ("pointcloud",)
|
234 |
+
|
235 |
+
@classmethod
|
236 |
+
def INPUT_TYPES(cls):
|
237 |
+
return {
|
238 |
+
"required": {
|
239 |
+
"file": ("STRING", {"default": None}),
|
240 |
+
}
|
241 |
+
}
|
242 |
+
|
243 |
+
def load_pointcloud(self, file):
|
244 |
+
if file is None or file == "":
|
245 |
+
return (None,)
|
246 |
+
# Load the mesh using trimesh
|
247 |
+
mesh = trimesh.load(file)
|
248 |
+
|
249 |
+
# Extract vertices and colors
|
250 |
+
vertices = mesh.vertices
|
251 |
+
|
252 |
+
# Get vertex colors, defaulting to white if none exist
|
253 |
+
if mesh.visual.vertex_colors is not None:
|
254 |
+
colors = (
|
255 |
+
mesh.visual.vertex_colors[:, :3] / 255.0
|
256 |
+
) # Convert 0-255 to 0-1 range
|
257 |
+
else:
|
258 |
+
colors = np.ones((len(vertices), 3))
|
259 |
+
|
260 |
+
# Interleave XYZ and RGB values
|
261 |
+
point_cloud = []
|
262 |
+
for vertex, color in zip(vertices, colors):
|
263 |
+
point_cloud.extend(
|
264 |
+
[
|
265 |
+
float(vertex[0]),
|
266 |
+
float(vertex[1]),
|
267 |
+
float(vertex[2]),
|
268 |
+
float(color[0]),
|
269 |
+
float(color[1]),
|
270 |
+
float(color[2]),
|
271 |
+
]
|
272 |
+
)
|
273 |
+
|
274 |
+
return (point_cloud,)
|
275 |
+
|
276 |
+
|
277 |
+
class SPAR3DPointCloudSaver:
|
278 |
+
CATEGORY = SPAR3D_CATEGORY
|
279 |
+
FUNCTION = "save_pointcloud"
|
280 |
+
OUTPUT_NODE = True
|
281 |
+
RETURN_TYPES = ()
|
282 |
+
|
283 |
+
@classmethod
|
284 |
+
def INPUT_TYPES(s):
|
285 |
+
return {
|
286 |
+
"required": {
|
287 |
+
"pointcloud": ("POINTCLOUD",),
|
288 |
+
"filename_prefix": ("STRING", {"default": "SPAR3D"}),
|
289 |
+
}
|
290 |
+
}
|
291 |
+
|
292 |
+
def save_pointcloud(self, pointcloud, filename_prefix):
|
293 |
+
if pointcloud is None:
|
294 |
+
return {"ui": {"text": "No point cloud data to save"}}
|
295 |
+
|
296 |
+
# Reshape the flat list into points with XYZ and RGB
|
297 |
+
points = np.array(pointcloud).reshape(-1, 6)
|
298 |
+
|
299 |
+
# Create vertex array for PLY
|
300 |
+
vertex_array = np.zeros(
|
301 |
+
len(points),
|
302 |
+
dtype=[
|
303 |
+
("x", "f4"),
|
304 |
+
("y", "f4"),
|
305 |
+
("z", "f4"),
|
306 |
+
("red", "u1"),
|
307 |
+
("green", "u1"),
|
308 |
+
("blue", "u1"),
|
309 |
+
],
|
310 |
+
)
|
311 |
+
|
312 |
+
# Fill vertex array
|
313 |
+
vertex_array["x"] = points[:, 0]
|
314 |
+
vertex_array["y"] = points[:, 1]
|
315 |
+
vertex_array["z"] = points[:, 2]
|
316 |
+
# Convert RGB from 0-1 to 0-255 range
|
317 |
+
vertex_array["red"] = (points[:, 3] * 255).astype(np.uint8)
|
318 |
+
vertex_array["green"] = (points[:, 4] * 255).astype(np.uint8)
|
319 |
+
vertex_array["blue"] = (points[:, 5] * 255).astype(np.uint8)
|
320 |
+
|
321 |
+
# Create PLY object
|
322 |
+
ply_data = trimesh.PointCloud(
|
323 |
+
vertices=points[:, :3], colors=points[:, 3:] * 255
|
324 |
+
)
|
325 |
+
|
326 |
+
# Save to file
|
327 |
+
output_dir = folder_paths.get_output_directory()
|
328 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = (
|
329 |
+
folder_paths.get_save_image_path(filename_prefix, output_dir)
|
330 |
+
)
|
331 |
+
out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.ply")
|
332 |
+
|
333 |
+
ply_data.export(out_path)
|
334 |
+
|
335 |
+
return {"ui": {"text": f"Saved point cloud to {out_path}"}}
|
336 |
+
|
337 |
+
|
338 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
339 |
+
"SPAR3DLoader": "SPAR3D Loader",
|
340 |
+
"SPAR3DPreview": "SPAR3D Preview",
|
341 |
+
"SPAR3DSampler": "SPAR3D Sampler",
|
342 |
+
"SPAR3DSave": "SPAR3D Save",
|
343 |
+
"SPAR3DPointCloudLoader": "SPAR3D Point Cloud Loader",
|
344 |
+
"SPAR3DPointCloudSaver": "SPAR3D Point Cloud Saver",
|
345 |
+
}
|
346 |
+
|
347 |
+
NODE_CLASS_MAPPINGS = {
|
348 |
+
"SPAR3DLoader": SPAR3DLoader,
|
349 |
+
"SPAR3DPreview": SPAR3DPreview,
|
350 |
+
"SPAR3DSampler": SPAR3DSampler,
|
351 |
+
"SPAR3DSave": SPAR3DSave,
|
352 |
+
"SPAR3DPointCloudLoader": SPAR3DPointCloudLoader,
|
353 |
+
"SPAR3DPointCloudSaver": SPAR3DPointCloudSaver,
|
354 |
+
}
|
355 |
+
|
356 |
+
WEB_DIRECTORY = "./comfyui"
|
357 |
+
|
358 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
demo_files/comp.gif
ADDED
Git LFS Details
|
demo_files/examples/bird.png
ADDED
Git LFS Details
|
demo_files/examples/castle.png
ADDED
Git LFS Details
|
demo_files/examples/chest.png
ADDED
Git LFS Details
|
demo_files/examples/doll.png
ADDED
Git LFS Details
|
demo_files/examples/excavator.png
ADDED
Git LFS Details
|
demo_files/examples/fish.png
ADDED
Git LFS Details
|
demo_files/examples/horse-statue.png
ADDED
Git LFS Details
|
demo_files/examples/penguin.png
ADDED
Git LFS Details
|
demo_files/examples/pot.png
ADDED
Git LFS Details
|
demo_files/examples/raccoon_wizard.png
ADDED
Git LFS Details
|
demo_files/examples/stylized-rocks.png
ADDED
Git LFS Details
|
demo_files/hdri/abandoned_tiled_room_1k.hdr
ADDED
Binary file (478 kB). View file
|
|
demo_files/hdri/metro_noord_1k.hdr
ADDED
Binary file (467 kB). View file
|
|
demo_files/hdri/neon_photostudio_1k.hdr
ADDED
Binary file (438 kB). View file
|
|
demo_files/hdri/peppermint_powerplant_1k.hdr
ADDED
Binary file (473 kB). View file
|
|
demo_files/hdri/rainforest_trail_1k.hdr
ADDED
Binary file (512 kB). View file
|
|
demo_files/hdri/studio_small_08_1k.hdr
ADDED
Binary file (412 kB). View file
|
|
demo_files/hdri/urban_alley_01_1k.hdr
ADDED
Binary file (458 kB). View file
|
|
demo_files/turntable.gif
ADDED
Git LFS Details
|
demo_files/workflows/spar3d_example.json
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"last_node_id": 17,
|
3 |
+
"last_link_id": 18,
|
4 |
+
"nodes": [
|
5 |
+
{
|
6 |
+
"id": 10,
|
7 |
+
"type": "SPAR3DLoader",
|
8 |
+
"pos": [
|
9 |
+
52.92446517944336,
|
10 |
+
394.328369140625
|
11 |
+
],
|
12 |
+
"size": [
|
13 |
+
210,
|
14 |
+
26
|
15 |
+
],
|
16 |
+
"flags": {},
|
17 |
+
"order": 0,
|
18 |
+
"mode": 0,
|
19 |
+
"inputs": [],
|
20 |
+
"outputs": [
|
21 |
+
{
|
22 |
+
"name": "spar3d_model",
|
23 |
+
"type": "SPAR3D_MODEL",
|
24 |
+
"links": [
|
25 |
+
10
|
26 |
+
],
|
27 |
+
"slot_index": 0
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"properties": {
|
31 |
+
"Node name for S&R": "SPAR3DLoader"
|
32 |
+
},
|
33 |
+
"widgets_values": []
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"id": 13,
|
37 |
+
"type": "LoadImage",
|
38 |
+
"pos": [
|
39 |
+
-43.437347412109375,
|
40 |
+
482.89678955078125
|
41 |
+
],
|
42 |
+
"size": [
|
43 |
+
315,
|
44 |
+
314
|
45 |
+
],
|
46 |
+
"flags": {},
|
47 |
+
"order": 1,
|
48 |
+
"mode": 0,
|
49 |
+
"inputs": [],
|
50 |
+
"outputs": [
|
51 |
+
{
|
52 |
+
"name": "IMAGE",
|
53 |
+
"type": "IMAGE",
|
54 |
+
"links": [
|
55 |
+
11
|
56 |
+
],
|
57 |
+
"slot_index": 0
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"name": "MASK",
|
61 |
+
"type": "MASK",
|
62 |
+
"links": [
|
63 |
+
16
|
64 |
+
],
|
65 |
+
"slot_index": 1
|
66 |
+
}
|
67 |
+
],
|
68 |
+
"properties": {
|
69 |
+
"Node name for S&R": "LoadImage"
|
70 |
+
},
|
71 |
+
"widgets_values": [
|
72 |
+
"cat1.png",
|
73 |
+
"image"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"id": 16,
|
78 |
+
"type": "InvertMask",
|
79 |
+
"pos": [
|
80 |
+
377.1180419921875,
|
81 |
+
605.384765625
|
82 |
+
],
|
83 |
+
"size": [
|
84 |
+
210,
|
85 |
+
26
|
86 |
+
],
|
87 |
+
"flags": {},
|
88 |
+
"order": 2,
|
89 |
+
"mode": 0,
|
90 |
+
"inputs": [
|
91 |
+
{
|
92 |
+
"name": "mask",
|
93 |
+
"type": "MASK",
|
94 |
+
"link": 16
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"outputs": [
|
98 |
+
{
|
99 |
+
"name": "MASK",
|
100 |
+
"type": "MASK",
|
101 |
+
"links": [
|
102 |
+
17
|
103 |
+
],
|
104 |
+
"slot_index": 0
|
105 |
+
}
|
106 |
+
],
|
107 |
+
"properties": {
|
108 |
+
"Node name for S&R": "InvertMask"
|
109 |
+
},
|
110 |
+
"widgets_values": []
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"id": 17,
|
114 |
+
"type": "SPAR3DSave",
|
115 |
+
"pos": [
|
116 |
+
1133.669921875,
|
117 |
+
439.6551513671875
|
118 |
+
],
|
119 |
+
"size": [
|
120 |
+
315,
|
121 |
+
58
|
122 |
+
],
|
123 |
+
"flags": {},
|
124 |
+
"order": 4,
|
125 |
+
"mode": 0,
|
126 |
+
"inputs": [
|
127 |
+
{
|
128 |
+
"name": "mesh",
|
129 |
+
"type": "MESH",
|
130 |
+
"link": 18
|
131 |
+
}
|
132 |
+
],
|
133 |
+
"outputs": [],
|
134 |
+
"properties": {
|
135 |
+
"Node name for S&R": "SPAR3DSave"
|
136 |
+
},
|
137 |
+
"widgets_values": [
|
138 |
+
"SPAR3D"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"id": 11,
|
143 |
+
"type": "SPAR3DSampler",
|
144 |
+
"pos": [
|
145 |
+
673.0637817382812,
|
146 |
+
441.2229309082031
|
147 |
+
],
|
148 |
+
"size": [
|
149 |
+
315,
|
150 |
+
286
|
151 |
+
],
|
152 |
+
"flags": {},
|
153 |
+
"order": 3,
|
154 |
+
"mode": 0,
|
155 |
+
"inputs": [
|
156 |
+
{
|
157 |
+
"name": "model",
|
158 |
+
"type": "SPAR3D_MODEL",
|
159 |
+
"link": 10
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"name": "image",
|
163 |
+
"type": "IMAGE",
|
164 |
+
"link": 11
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"name": "mask",
|
168 |
+
"type": "MASK",
|
169 |
+
"link": 17,
|
170 |
+
"shape": 7
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"name": "pointcloud",
|
174 |
+
"type": "POINTCLOUD",
|
175 |
+
"link": null,
|
176 |
+
"shape": 7
|
177 |
+
}
|
178 |
+
],
|
179 |
+
"outputs": [
|
180 |
+
{
|
181 |
+
"name": "mesh",
|
182 |
+
"type": "MESH",
|
183 |
+
"links": [
|
184 |
+
18
|
185 |
+
],
|
186 |
+
"slot_index": 0
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"name": "pointcloud",
|
190 |
+
"type": "POINTCLOUD",
|
191 |
+
"links": null
|
192 |
+
}
|
193 |
+
],
|
194 |
+
"properties": {
|
195 |
+
"Node name for S&R": "SPAR3DSampler"
|
196 |
+
},
|
197 |
+
"widgets_values": [
|
198 |
+
1.3,
|
199 |
+
1024,
|
200 |
+
"none",
|
201 |
+
1000,
|
202 |
+
3,
|
203 |
+
3727502160,
|
204 |
+
"randomize",
|
205 |
+
"none"
|
206 |
+
]
|
207 |
+
}
|
208 |
+
],
|
209 |
+
"links": [
|
210 |
+
[
|
211 |
+
10,
|
212 |
+
10,
|
213 |
+
0,
|
214 |
+
11,
|
215 |
+
0,
|
216 |
+
"SPAR3D_MODEL"
|
217 |
+
],
|
218 |
+
[
|
219 |
+
11,
|
220 |
+
13,
|
221 |
+
0,
|
222 |
+
11,
|
223 |
+
1,
|
224 |
+
"IMAGE"
|
225 |
+
],
|
226 |
+
[
|
227 |
+
16,
|
228 |
+
13,
|
229 |
+
1,
|
230 |
+
16,
|
231 |
+
0,
|
232 |
+
"MASK"
|
233 |
+
],
|
234 |
+
[
|
235 |
+
17,
|
236 |
+
16,
|
237 |
+
0,
|
238 |
+
11,
|
239 |
+
2,
|
240 |
+
"MASK"
|
241 |
+
],
|
242 |
+
[
|
243 |
+
18,
|
244 |
+
11,
|
245 |
+
0,
|
246 |
+
17,
|
247 |
+
0,
|
248 |
+
"MESH"
|
249 |
+
]
|
250 |
+
],
|
251 |
+
"groups": [],
|
252 |
+
"config": {},
|
253 |
+
"extra": {
|
254 |
+
"ds": {
|
255 |
+
"scale": 0.953502721998243,
|
256 |
+
"offset": [
|
257 |
+
266.21995970220667,
|
258 |
+
116.75398112171928
|
259 |
+
]
|
260 |
+
}
|
261 |
+
},
|
262 |
+
"version": 0.4
|
263 |
+
}
|
gradio_app.py
ADDED
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.system("pip install ./texture_baker/ ./uv_unwrapper/")
|
4 |
+
|
5 |
+
import random
|
6 |
+
import tempfile
|
7 |
+
import time
|
8 |
+
from contextlib import nullcontext
|
9 |
+
from functools import lru_cache
|
10 |
+
from typing import Any
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import trimesh
|
16 |
+
from gradio_litmodel3d import LitModel3D
|
17 |
+
from gradio_pointcloudeditor import PointCloudEditor
|
18 |
+
from PIL import Image
|
19 |
+
from transparent_background import Remover
|
20 |
+
|
21 |
+
import spar3d.utils as spar3d_utils
|
22 |
+
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
|
23 |
+
from spar3d.system import SPAR3D
|
24 |
+
|
25 |
+
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")
|
26 |
+
|
27 |
+
bg_remover = Remover() # default setting
|
28 |
+
|
29 |
+
COND_WIDTH = 512
|
30 |
+
COND_HEIGHT = 512
|
31 |
+
COND_DISTANCE = 2.2
|
32 |
+
COND_FOVY = 0.591627
|
33 |
+
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
|
34 |
+
|
35 |
+
# Cached. Doesn't change
|
36 |
+
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
|
37 |
+
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
|
38 |
+
COND_FOVY, COND_HEIGHT, COND_WIDTH
|
39 |
+
)
|
40 |
+
|
41 |
+
generated_files = []
|
42 |
+
|
43 |
+
# Delete previous gradio temp dir folder
|
44 |
+
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
45 |
+
print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
|
46 |
+
import shutil
|
47 |
+
|
48 |
+
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
|
49 |
+
|
50 |
+
device = spar3d_utils.get_device()
|
51 |
+
|
52 |
+
model = SPAR3D.from_pretrained(
|
53 |
+
"stabilityai/stable-point-aware-3d",
|
54 |
+
config_name="config.yaml",
|
55 |
+
weight_name="model.safetensors",
|
56 |
+
)
|
57 |
+
model.eval()
|
58 |
+
model = model.to(device)
|
59 |
+
|
60 |
+
example_files = [
|
61 |
+
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
62 |
+
]
|
63 |
+
|
64 |
+
|
65 |
+
def forward_model(
|
66 |
+
batch,
|
67 |
+
system,
|
68 |
+
guidance_scale=3.0,
|
69 |
+
seed=0,
|
70 |
+
device="cuda",
|
71 |
+
remesh_option="none",
|
72 |
+
vertex_count=-1,
|
73 |
+
texture_resolution=1024,
|
74 |
+
):
|
75 |
+
batch_size = batch["rgb_cond"].shape[0]
|
76 |
+
|
77 |
+
# prepare the condition for point cloud generation
|
78 |
+
# set seed
|
79 |
+
random.seed(seed)
|
80 |
+
torch.manual_seed(seed)
|
81 |
+
np.random.seed(seed)
|
82 |
+
cond_tokens = system.forward_pdiff_cond(batch)
|
83 |
+
|
84 |
+
if "pc_cond" not in batch:
|
85 |
+
sample_iter = system.sampler.sample_batch_progressive(
|
86 |
+
batch_size,
|
87 |
+
cond_tokens,
|
88 |
+
guidance_scale=guidance_scale,
|
89 |
+
device=device,
|
90 |
+
)
|
91 |
+
for x in sample_iter:
|
92 |
+
samples = x["xstart"]
|
93 |
+
batch["pc_cond"] = samples.permute(0, 2, 1).float()
|
94 |
+
batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])
|
95 |
+
|
96 |
+
# subsample to the 512 points
|
97 |
+
batch["pc_cond"] = batch["pc_cond"][
|
98 |
+
:, torch.randperm(batch["pc_cond"].shape[1])[:512]
|
99 |
+
]
|
100 |
+
|
101 |
+
# get the point cloud
|
102 |
+
xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
|
103 |
+
color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
|
104 |
+
pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)
|
105 |
+
|
106 |
+
# forward for the final mesh
|
107 |
+
trimesh_mesh, _glob_dict = model.generate_mesh(
|
108 |
+
batch, texture_resolution, remesh=remesh_option, vertex_count=vertex_count
|
109 |
+
)
|
110 |
+
trimesh_mesh = trimesh_mesh[0]
|
111 |
+
|
112 |
+
return trimesh_mesh, pc_rgb_trimesh
|
113 |
+
|
114 |
+
|
115 |
+
def run_model(
|
116 |
+
input_image,
|
117 |
+
guidance_scale,
|
118 |
+
random_seed,
|
119 |
+
pc_cond,
|
120 |
+
remesh_option,
|
121 |
+
vertex_count,
|
122 |
+
texture_resolution,
|
123 |
+
):
|
124 |
+
start = time.time()
|
125 |
+
with torch.no_grad():
|
126 |
+
with (
|
127 |
+
torch.autocast(device_type=device, dtype=torch.float16)
|
128 |
+
if "cuda" in device
|
129 |
+
else nullcontext()
|
130 |
+
):
|
131 |
+
model_batch = create_batch(input_image)
|
132 |
+
model_batch = {k: v.to(device) for k, v in model_batch.items()}
|
133 |
+
|
134 |
+
if pc_cond is not None:
|
135 |
+
# Check if pc_cond is a list
|
136 |
+
if isinstance(pc_cond, list):
|
137 |
+
cond_tensor = torch.tensor(pc_cond).float().cuda().view(-1, 6)
|
138 |
+
xyz = cond_tensor[:, :3]
|
139 |
+
color_rgb = cond_tensor[:, 3:]
|
140 |
+
elif isinstance(pc_cond, dict):
|
141 |
+
xyz = torch.tensor(pc_cond["positions"]).float().cuda()
|
142 |
+
color_rgb = torch.tensor(pc_cond["colors"]).float().cuda()
|
143 |
+
else:
|
144 |
+
xyz = torch.tensor(pc_cond.vertices).float().cuda()
|
145 |
+
color_rgb = (
|
146 |
+
torch.tensor(pc_cond.colors[:, :3]).float().cuda() / 255.0
|
147 |
+
)
|
148 |
+
model_batch["pc_cond"] = torch.cat([xyz, color_rgb], dim=-1).unsqueeze(
|
149 |
+
0
|
150 |
+
)
|
151 |
+
# sub-sample the point cloud to the target number of points
|
152 |
+
if model_batch["pc_cond"].shape[1] > 512:
|
153 |
+
idx = torch.randperm(model_batch["pc_cond"].shape[1])[:512]
|
154 |
+
model_batch["pc_cond"] = model_batch["pc_cond"][:, idx]
|
155 |
+
elif model_batch["pc_cond"].shape[1] < 512:
|
156 |
+
num_points = model_batch["pc_cond"].shape[1]
|
157 |
+
gr.Warning(
|
158 |
+
f"The uploaded point cloud should have at least 512 points. This point cloud only has {num_points}. Results may be worse."
|
159 |
+
)
|
160 |
+
pad = 512 - num_points
|
161 |
+
sampled_idx = torch.randint(
|
162 |
+
0, model_batch["pc_cond"].shape[1], (pad,)
|
163 |
+
)
|
164 |
+
model_batch["pc_cond"] = torch.cat(
|
165 |
+
[
|
166 |
+
model_batch["pc_cond"],
|
167 |
+
model_batch["pc_cond"][:, sampled_idx],
|
168 |
+
],
|
169 |
+
dim=1,
|
170 |
+
)
|
171 |
+
|
172 |
+
trimesh_mesh, trimesh_pc = forward_model(
|
173 |
+
model_batch,
|
174 |
+
model,
|
175 |
+
guidance_scale=guidance_scale,
|
176 |
+
seed=random_seed,
|
177 |
+
device="cuda",
|
178 |
+
remesh_option=remesh_option.lower(),
|
179 |
+
vertex_count=vertex_count,
|
180 |
+
texture_resolution=texture_resolution,
|
181 |
+
)
|
182 |
+
|
183 |
+
# Create new tmp file
|
184 |
+
temp_dir = tempfile.mkdtemp()
|
185 |
+
tmp_file = os.path.join(temp_dir, "mesh.glb")
|
186 |
+
|
187 |
+
trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
|
188 |
+
generated_files.append(tmp_file)
|
189 |
+
|
190 |
+
tmp_file_pc = os.path.join(temp_dir, "points.ply")
|
191 |
+
trimesh_pc.export(tmp_file_pc)
|
192 |
+
generated_files.append(tmp_file_pc)
|
193 |
+
|
194 |
+
print("Generation took:", time.time() - start, "s")
|
195 |
+
|
196 |
+
return tmp_file, tmp_file_pc, trimesh_pc
|
197 |
+
|
198 |
+
|
199 |
+
def create_batch(input_image: Image) -> dict[str, Any]:
|
200 |
+
img_cond = (
|
201 |
+
torch.from_numpy(
|
202 |
+
np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
|
203 |
+
/ 255.0
|
204 |
+
)
|
205 |
+
.float()
|
206 |
+
.clip(0, 1)
|
207 |
+
)
|
208 |
+
mask_cond = img_cond[:, :, -1:]
|
209 |
+
rgb_cond = torch.lerp(
|
210 |
+
torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
|
211 |
+
)
|
212 |
+
|
213 |
+
batch_elem = {
|
214 |
+
"rgb_cond": rgb_cond,
|
215 |
+
"mask_cond": mask_cond,
|
216 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
217 |
+
"intrinsic_cond": intrinsic.unsqueeze(0),
|
218 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
|
219 |
+
}
|
220 |
+
# Add batch dim
|
221 |
+
batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
|
222 |
+
return batched
|
223 |
+
|
224 |
+
|
225 |
+
@lru_cache
|
226 |
+
def checkerboard(squares: int, size: int, min_value: float = 0.5):
|
227 |
+
base = np.zeros((squares, squares)) + min_value
|
228 |
+
base[1::2, ::2] = 1
|
229 |
+
base[::2, 1::2] = 1
|
230 |
+
|
231 |
+
repeat_mult = size // squares
|
232 |
+
return (
|
233 |
+
base.repeat(repeat_mult, axis=0)
|
234 |
+
.repeat(repeat_mult, axis=1)[:, :, None]
|
235 |
+
.repeat(3, axis=-1)
|
236 |
+
)
|
237 |
+
|
238 |
+
|
239 |
+
def remove_background(input_image: Image) -> Image:
|
240 |
+
return bg_remover.process(input_image.convert("RGB"))
|
241 |
+
|
242 |
+
|
243 |
+
def show_mask_img(input_image: Image) -> Image:
|
244 |
+
img_numpy = np.array(input_image)
|
245 |
+
alpha = img_numpy[:, :, 3] / 255.0
|
246 |
+
chkb = checkerboard(32, 512) * 255
|
247 |
+
new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
|
248 |
+
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
249 |
+
|
250 |
+
|
251 |
+
def process_model_run(
|
252 |
+
background_state,
|
253 |
+
guidance_scale,
|
254 |
+
random_seed,
|
255 |
+
pc_cond,
|
256 |
+
remesh_option,
|
257 |
+
vertex_count_type,
|
258 |
+
vertex_count,
|
259 |
+
texture_resolution,
|
260 |
+
):
|
261 |
+
# Adjust vertex count based on selection
|
262 |
+
final_vertex_count = (
|
263 |
+
-1
|
264 |
+
if vertex_count_type == "Keep Vertex Count"
|
265 |
+
else (
|
266 |
+
vertex_count // 2
|
267 |
+
if vertex_count_type == "Target Face Count"
|
268 |
+
else vertex_count
|
269 |
+
)
|
270 |
+
)
|
271 |
+
print(
|
272 |
+
f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
|
273 |
+
)
|
274 |
+
|
275 |
+
glb_file, pc_file, pc_plot = run_model(
|
276 |
+
background_state,
|
277 |
+
guidance_scale,
|
278 |
+
random_seed,
|
279 |
+
pc_cond,
|
280 |
+
remesh_option,
|
281 |
+
final_vertex_count,
|
282 |
+
texture_resolution,
|
283 |
+
)
|
284 |
+
# Create a single float list of x y z r g b
|
285 |
+
point_list = []
|
286 |
+
for i in range(pc_plot.vertices.shape[0]):
|
287 |
+
point_list.extend(
|
288 |
+
[
|
289 |
+
pc_plot.vertices[i, 0],
|
290 |
+
pc_plot.vertices[i, 1],
|
291 |
+
pc_plot.vertices[i, 2],
|
292 |
+
pc_plot.colors[i, 0] / 255.0,
|
293 |
+
pc_plot.colors[i, 1] / 255.0,
|
294 |
+
pc_plot.colors[i, 2] / 255.0,
|
295 |
+
]
|
296 |
+
)
|
297 |
+
|
298 |
+
return glb_file, pc_file, point_list
|
299 |
+
|
300 |
+
|
301 |
+
def regenerate_run(
|
302 |
+
background_state,
|
303 |
+
guidance_scale,
|
304 |
+
random_seed,
|
305 |
+
pc_cond,
|
306 |
+
remesh_option,
|
307 |
+
vertex_count_type,
|
308 |
+
vertex_count,
|
309 |
+
texture_resolution,
|
310 |
+
):
|
311 |
+
glb_file, pc_file, point_list = process_model_run(
|
312 |
+
background_state,
|
313 |
+
guidance_scale,
|
314 |
+
random_seed,
|
315 |
+
pc_cond,
|
316 |
+
remesh_option,
|
317 |
+
vertex_count_type,
|
318 |
+
vertex_count,
|
319 |
+
texture_resolution,
|
320 |
+
)
|
321 |
+
return (
|
322 |
+
gr.update(), # run_btn
|
323 |
+
gr.update(), # img_proc_state
|
324 |
+
gr.update(), # background_remove_state
|
325 |
+
gr.update(), # preview_removal
|
326 |
+
gr.update(value=glb_file, visible=True), # output_3d
|
327 |
+
gr.update(visible=True), # hdr_row
|
328 |
+
gr.update(visible=True), # point_cloud_row
|
329 |
+
gr.update(value=point_list), # point_cloud_editor
|
330 |
+
gr.update(value=pc_file), # pc_download
|
331 |
+
gr.update(visible=False), # regenerate_btn
|
332 |
+
)
|
333 |
+
|
334 |
+
|
335 |
+
def run_button(
|
336 |
+
run_btn,
|
337 |
+
input_image,
|
338 |
+
background_state,
|
339 |
+
foreground_ratio,
|
340 |
+
no_crop,
|
341 |
+
guidance_scale,
|
342 |
+
random_seed,
|
343 |
+
pc_upload,
|
344 |
+
pc_cond_file,
|
345 |
+
remesh_option,
|
346 |
+
vertex_count_type,
|
347 |
+
vertex_count,
|
348 |
+
texture_resolution,
|
349 |
+
):
|
350 |
+
if run_btn == "Run":
|
351 |
+
if torch.cuda.is_available():
|
352 |
+
torch.cuda.reset_peak_memory_stats()
|
353 |
+
|
354 |
+
if pc_upload:
|
355 |
+
# make sure the pc_cond_file has been uploaded
|
356 |
+
try:
|
357 |
+
pc_cond = trimesh.load(pc_cond_file.name)
|
358 |
+
except Exception:
|
359 |
+
raise gr.Error(
|
360 |
+
"Please upload a valid point cloud ply file as condition."
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
pc_cond = None
|
364 |
+
|
365 |
+
glb_file, pc_file, pc_list = process_model_run(
|
366 |
+
background_state,
|
367 |
+
guidance_scale,
|
368 |
+
random_seed,
|
369 |
+
pc_cond,
|
370 |
+
remesh_option,
|
371 |
+
vertex_count_type,
|
372 |
+
vertex_count,
|
373 |
+
texture_resolution,
|
374 |
+
)
|
375 |
+
|
376 |
+
if torch.cuda.is_available():
|
377 |
+
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
378 |
+
elif torch.backends.mps.is_available():
|
379 |
+
print(
|
380 |
+
"Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
|
381 |
+
)
|
382 |
+
|
383 |
+
return (
|
384 |
+
gr.update(), # run_btn
|
385 |
+
gr.update(), # img_proc_state
|
386 |
+
gr.update(), # background_remove_state
|
387 |
+
gr.update(), # preview_removal
|
388 |
+
gr.update(value=glb_file, visible=True), # output_3d
|
389 |
+
gr.update(visible=True), # hdr_row
|
390 |
+
gr.update(visible=True), # point_cloud_row
|
391 |
+
gr.update(value=pc_list), # point_cloud_editor
|
392 |
+
gr.update(value=pc_file), # pc_download
|
393 |
+
gr.update(visible=False), # regenerate_btn
|
394 |
+
)
|
395 |
+
|
396 |
+
elif run_btn == "Remove Background":
|
397 |
+
rem_removed = remove_background(input_image)
|
398 |
+
|
399 |
+
fr_res = spar3d_utils.foreground_crop(
|
400 |
+
rem_removed,
|
401 |
+
crop_ratio=foreground_ratio,
|
402 |
+
newsize=(COND_WIDTH, COND_HEIGHT),
|
403 |
+
no_crop=no_crop,
|
404 |
+
)
|
405 |
+
|
406 |
+
return (
|
407 |
+
gr.update(value="Run", visible=True), # run_btn
|
408 |
+
rem_removed, # img_proc_state,
|
409 |
+
fr_res, # background_remove_state
|
410 |
+
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
411 |
+
gr.update(value=None, visible=False), # output_3d
|
412 |
+
gr.update(visible=False), # hdr_row
|
413 |
+
gr.update(visible=False), # point_cloud_row
|
414 |
+
gr.update(value=None), # point_cloud_editor
|
415 |
+
gr.update(value=None), # pc_download
|
416 |
+
gr.update(visible=False), # regenerate_btn
|
417 |
+
)
|
418 |
+
|
419 |
+
|
420 |
+
def requires_bg_remove(image, fr, no_crop):
|
421 |
+
if image is None:
|
422 |
+
return (
|
423 |
+
gr.update(visible=False, value="Run"), # run_Btn
|
424 |
+
None, # img_proc_state
|
425 |
+
None, # background_remove_state
|
426 |
+
gr.update(value=None, visible=False), # preview_removal
|
427 |
+
gr.update(value=None, visible=False), # output_3d
|
428 |
+
gr.update(visible=False), # hdr_row
|
429 |
+
gr.update(visible=False), # point_cloud_row
|
430 |
+
gr.update(value=None), # point_cloud_editor
|
431 |
+
gr.update(value=None), # pc_download
|
432 |
+
gr.update(visible=False), # regenerate_btn
|
433 |
+
)
|
434 |
+
alpha_channel = np.array(image.getchannel("A"))
|
435 |
+
min_alpha = alpha_channel.min()
|
436 |
+
|
437 |
+
if min_alpha == 0:
|
438 |
+
print("Already has alpha")
|
439 |
+
fr_res = spar3d_utils.foreground_crop(
|
440 |
+
image, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
|
441 |
+
)
|
442 |
+
return (
|
443 |
+
gr.update(value="Run", visible=True), # run_Btn
|
444 |
+
image, # img_proc_state
|
445 |
+
fr_res, # background_remove_state
|
446 |
+
gr.update(value=show_mask_img(fr_res), visible=True), # preview_removal
|
447 |
+
gr.update(value=None, visible=False), # output_3d
|
448 |
+
gr.update(visible=False), # hdr_row
|
449 |
+
gr.update(visible=False), # point_cloud_row
|
450 |
+
gr.update(value=None), # point_cloud_editor
|
451 |
+
gr.update(value=None), # pc_download
|
452 |
+
gr.update(visible=False), # regenerate_btn
|
453 |
+
)
|
454 |
+
return (
|
455 |
+
gr.update(value="Remove Background", visible=True), # run_Btn
|
456 |
+
None, # img_proc_state
|
457 |
+
None, # background_remove_state
|
458 |
+
gr.update(value=None, visible=False), # preview_removal
|
459 |
+
gr.update(value=None, visible=False), # output_3d
|
460 |
+
gr.update(visible=False), # hdr_row
|
461 |
+
gr.update(visible=False), # point_cloud_row
|
462 |
+
gr.update(value=None), # point_cloud_editor
|
463 |
+
gr.update(value=None), # pc_download
|
464 |
+
gr.update(visible=False), # regenerate_btn
|
465 |
+
)
|
466 |
+
|
467 |
+
|
468 |
+
def update_foreground_ratio(img_proc, fr, no_crop):
|
469 |
+
foreground_res = spar3d_utils.foreground_crop(
|
470 |
+
img_proc, fr, newsize=(COND_WIDTH, COND_HEIGHT), no_crop=no_crop
|
471 |
+
)
|
472 |
+
return (
|
473 |
+
foreground_res,
|
474 |
+
gr.update(value=show_mask_img(foreground_res)),
|
475 |
+
)
|
476 |
+
|
477 |
+
|
478 |
+
def update_resolution_controls(remesh_choice, vertex_count_type):
|
479 |
+
show_controls = remesh_choice.lower() != "none"
|
480 |
+
show_vertex_count = vertex_count_type != "Keep Vertex Count"
|
481 |
+
return (
|
482 |
+
gr.update(visible=show_controls), # vertex_count_type
|
483 |
+
gr.update(visible=show_controls and show_vertex_count), # vertex_count_slider
|
484 |
+
)
|
485 |
+
|
486 |
+
|
487 |
+
with gr.Blocks() as demo:
|
488 |
+
img_proc_state = gr.State()
|
489 |
+
background_remove_state = gr.State()
|
490 |
+
gr.Markdown(
|
491 |
+
"""
|
492 |
+
# SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images
|
493 |
+
|
494 |
+
SPAR3D is a state-of-the-art method for 3D mesh reconstruction from a single image. This demo allows you to upload an image and generate a 3D mesh model from it. A feature of SPAR3D is it generates point clouds as intermediate representation before producing the mesh. You can edit the point cloud to adjust the final mesh. We provide a simple point cloud editor in this demo, where you can drag, recolor and rescale the point clouds. If you have more advanced editing needs (e.g. box selection, duplication, local streching, etc.), you can download the point cloud and edit it in softwares such as MeshLab or Blender. The edited point cloud can then be uploaded to this demo to generate a new 3D model by checking the "Point cloud upload" box.
|
495 |
+
|
496 |
+
**Tips**
|
497 |
+
|
498 |
+
1. If the image does not have a valid alpha channel, it will go through the background removal step. Our built-in background removal can be inaccurate sometimes, which will result in poor mesh quality. In such cases, you can use external background removal tools to obtain a RGBA image before uploading here.
|
499 |
+
2. You can adjust the foreground ratio to control the size of the foreground object. This may have major impact on the final mesh.
|
500 |
+
3. Guidance scale controls the strength of the image condition in the point cloud generation process. A higher value may result in higher mesh fidelity, but the variability by changing the random seed will be lower. Note that the guidance scale and the seed are not effective when the point cloud is manually uploaded.
|
501 |
+
4. Our online editor supports multi-selection by holding down the shift key. This allows you to recolor multiple points at once.
|
502 |
+
5. The editing should mainly alter the unseen parts of the object. Visible parts can be edited, but the edits should be consistent with the image. Editing the visible parts in a way that contradicts the image may result in poor mesh quality.
|
503 |
+
6. You can upload your own HDR environment map to light the 3D model.
|
504 |
+
"""
|
505 |
+
)
|
506 |
+
with gr.Row(variant="panel"):
|
507 |
+
with gr.Column():
|
508 |
+
with gr.Row():
|
509 |
+
input_img = gr.Image(
|
510 |
+
type="pil", label="Input Image", sources="upload", image_mode="RGBA"
|
511 |
+
)
|
512 |
+
preview_removal = gr.Image(
|
513 |
+
label="Preview Background Removal",
|
514 |
+
type="pil",
|
515 |
+
image_mode="RGB",
|
516 |
+
interactive=False,
|
517 |
+
visible=False,
|
518 |
+
)
|
519 |
+
|
520 |
+
gr.Markdown("### Input Controls")
|
521 |
+
with gr.Group():
|
522 |
+
with gr.Row():
|
523 |
+
no_crop = gr.Checkbox(label="No cropping", value=False)
|
524 |
+
pc_upload = gr.Checkbox(label="Point cloud upload", value=False)
|
525 |
+
|
526 |
+
pc_cond_file = gr.File(
|
527 |
+
label="Point Cloud Upload",
|
528 |
+
file_types=[".ply"],
|
529 |
+
file_count="single",
|
530 |
+
visible=False,
|
531 |
+
)
|
532 |
+
|
533 |
+
foreground_ratio = gr.Slider(
|
534 |
+
label="Padding Ratio",
|
535 |
+
minimum=1.0,
|
536 |
+
maximum=2.0,
|
537 |
+
value=1.3,
|
538 |
+
step=0.05,
|
539 |
+
)
|
540 |
+
|
541 |
+
pc_upload.change(
|
542 |
+
lambda x: gr.update(visible=x),
|
543 |
+
inputs=pc_upload,
|
544 |
+
outputs=[pc_cond_file],
|
545 |
+
)
|
546 |
+
|
547 |
+
no_crop.change(
|
548 |
+
update_foreground_ratio,
|
549 |
+
inputs=[img_proc_state, foreground_ratio, no_crop],
|
550 |
+
outputs=[background_remove_state, preview_removal],
|
551 |
+
)
|
552 |
+
|
553 |
+
foreground_ratio.change(
|
554 |
+
update_foreground_ratio,
|
555 |
+
inputs=[img_proc_state, foreground_ratio, no_crop],
|
556 |
+
outputs=[background_remove_state, preview_removal],
|
557 |
+
)
|
558 |
+
|
559 |
+
gr.Markdown("### Point Diffusion Controls")
|
560 |
+
with gr.Group():
|
561 |
+
guidance_scale = gr.Slider(
|
562 |
+
label="Guidance Scale",
|
563 |
+
minimum=1.0,
|
564 |
+
maximum=10.0,
|
565 |
+
value=3.0,
|
566 |
+
step=1.0,
|
567 |
+
)
|
568 |
+
|
569 |
+
random_seed = gr.Slider(
|
570 |
+
label="Seed",
|
571 |
+
minimum=0,
|
572 |
+
maximum=10000,
|
573 |
+
value=0,
|
574 |
+
step=1,
|
575 |
+
)
|
576 |
+
|
577 |
+
no_remesh = not TRIANGLE_REMESH_AVAILABLE and not QUAD_REMESH_AVAILABLE
|
578 |
+
gr.Markdown(
|
579 |
+
"### Texture Controls"
|
580 |
+
if no_remesh
|
581 |
+
else "### Meshing and Texture Controls"
|
582 |
+
)
|
583 |
+
with gr.Group():
|
584 |
+
remesh_choices = ["None"]
|
585 |
+
if TRIANGLE_REMESH_AVAILABLE:
|
586 |
+
remesh_choices.append("Triangle")
|
587 |
+
if QUAD_REMESH_AVAILABLE:
|
588 |
+
remesh_choices.append("Quad")
|
589 |
+
|
590 |
+
remesh_option = gr.Radio(
|
591 |
+
choices=remesh_choices,
|
592 |
+
label="Remeshing",
|
593 |
+
value="None",
|
594 |
+
visible=not no_remesh,
|
595 |
+
)
|
596 |
+
|
597 |
+
vertex_count_type = gr.Radio(
|
598 |
+
choices=[
|
599 |
+
"Keep Vertex Count",
|
600 |
+
"Target Vertex Count",
|
601 |
+
"Target Face Count",
|
602 |
+
],
|
603 |
+
label="Mesh Resolution Control",
|
604 |
+
value="Keep Vertex Count",
|
605 |
+
visible=False,
|
606 |
+
)
|
607 |
+
|
608 |
+
vertex_count_slider = gr.Slider(
|
609 |
+
label="Target Count",
|
610 |
+
minimum=0,
|
611 |
+
maximum=20000,
|
612 |
+
value=2000,
|
613 |
+
visible=False,
|
614 |
+
)
|
615 |
+
|
616 |
+
texture_size = gr.Slider(
|
617 |
+
label="Texture Size",
|
618 |
+
minimum=512,
|
619 |
+
maximum=2048,
|
620 |
+
value=1024,
|
621 |
+
step=256,
|
622 |
+
visible=True,
|
623 |
+
)
|
624 |
+
|
625 |
+
remesh_option.change(
|
626 |
+
update_resolution_controls,
|
627 |
+
inputs=[remesh_option, vertex_count_type],
|
628 |
+
outputs=[vertex_count_type, vertex_count_slider],
|
629 |
+
)
|
630 |
+
|
631 |
+
vertex_count_type.change(
|
632 |
+
update_resolution_controls,
|
633 |
+
inputs=[remesh_option, vertex_count_type],
|
634 |
+
outputs=[vertex_count_type, vertex_count_slider],
|
635 |
+
)
|
636 |
+
|
637 |
+
run_btn = gr.Button("Run", variant="primary", visible=False)
|
638 |
+
|
639 |
+
with gr.Column():
|
640 |
+
with gr.Group(visible=False) as point_cloud_row:
|
641 |
+
point_size_slider = gr.Slider(
|
642 |
+
label="Point Size",
|
643 |
+
minimum=0.01,
|
644 |
+
maximum=1.0,
|
645 |
+
value=0.2,
|
646 |
+
step=0.01,
|
647 |
+
)
|
648 |
+
point_cloud_editor = PointCloudEditor(
|
649 |
+
up_axis="Z",
|
650 |
+
forward_axis="X",
|
651 |
+
lock_scale_z=True,
|
652 |
+
lock_scale_y=True,
|
653 |
+
visible=True,
|
654 |
+
)
|
655 |
+
|
656 |
+
pc_download = gr.File(
|
657 |
+
label="Point Cloud Download",
|
658 |
+
file_types=[".ply"],
|
659 |
+
file_count="single",
|
660 |
+
)
|
661 |
+
point_size_slider.change(
|
662 |
+
fn=lambda x: gr.update(point_size=x),
|
663 |
+
inputs=point_size_slider,
|
664 |
+
outputs=point_cloud_editor,
|
665 |
+
)
|
666 |
+
|
667 |
+
regenerate_btn = gr.Button(
|
668 |
+
"Re-run with point cloud", variant="primary", visible=False
|
669 |
+
)
|
670 |
+
|
671 |
+
output_3d = LitModel3D(
|
672 |
+
label="3D Model",
|
673 |
+
visible=False,
|
674 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
675 |
+
tonemapping="aces",
|
676 |
+
contrast=1.0,
|
677 |
+
scale=1.0,
|
678 |
+
)
|
679 |
+
with gr.Column(visible=False, scale=1.0) as hdr_row:
|
680 |
+
gr.Markdown(
|
681 |
+
"""## HDR Environment Map
|
682 |
+
|
683 |
+
Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
|
684 |
+
"""
|
685 |
+
)
|
686 |
+
|
687 |
+
with gr.Row():
|
688 |
+
hdr_illumination_file = gr.File(
|
689 |
+
label="HDR Env Map",
|
690 |
+
file_types=[".hdr"],
|
691 |
+
file_count="single",
|
692 |
+
)
|
693 |
+
example_hdris = [
|
694 |
+
os.path.join("demo_files/hdri", f)
|
695 |
+
for f in os.listdir("demo_files/hdri")
|
696 |
+
]
|
697 |
+
hdr_illumination_example = gr.Examples(
|
698 |
+
examples=example_hdris,
|
699 |
+
inputs=hdr_illumination_file,
|
700 |
+
)
|
701 |
+
|
702 |
+
hdr_illumination_file.change(
|
703 |
+
lambda x: gr.update(env_map=x.name if x is not None else None),
|
704 |
+
inputs=hdr_illumination_file,
|
705 |
+
outputs=[output_3d],
|
706 |
+
)
|
707 |
+
|
708 |
+
examples = gr.Examples(
|
709 |
+
examples=example_files, inputs=input_img, examples_per_page=11
|
710 |
+
)
|
711 |
+
|
712 |
+
input_img.change(
|
713 |
+
requires_bg_remove,
|
714 |
+
inputs=[input_img, foreground_ratio, no_crop],
|
715 |
+
outputs=[
|
716 |
+
run_btn,
|
717 |
+
img_proc_state,
|
718 |
+
background_remove_state,
|
719 |
+
preview_removal,
|
720 |
+
output_3d,
|
721 |
+
hdr_row,
|
722 |
+
point_cloud_row,
|
723 |
+
point_cloud_editor,
|
724 |
+
pc_download,
|
725 |
+
regenerate_btn,
|
726 |
+
],
|
727 |
+
)
|
728 |
+
|
729 |
+
point_cloud_editor.edit(
|
730 |
+
fn=lambda _x: gr.update(visible=True),
|
731 |
+
inputs=point_cloud_editor,
|
732 |
+
outputs=regenerate_btn,
|
733 |
+
)
|
734 |
+
|
735 |
+
regenerate_btn.click(
|
736 |
+
regenerate_run,
|
737 |
+
inputs=[
|
738 |
+
background_remove_state,
|
739 |
+
guidance_scale,
|
740 |
+
random_seed,
|
741 |
+
point_cloud_editor,
|
742 |
+
remesh_option,
|
743 |
+
vertex_count_type,
|
744 |
+
vertex_count_slider,
|
745 |
+
texture_size,
|
746 |
+
],
|
747 |
+
outputs=[
|
748 |
+
run_btn,
|
749 |
+
img_proc_state,
|
750 |
+
background_remove_state,
|
751 |
+
preview_removal,
|
752 |
+
output_3d,
|
753 |
+
hdr_row,
|
754 |
+
point_cloud_row,
|
755 |
+
point_cloud_editor,
|
756 |
+
pc_download,
|
757 |
+
regenerate_btn,
|
758 |
+
],
|
759 |
+
)
|
760 |
+
|
761 |
+
run_btn.click(
|
762 |
+
run_button,
|
763 |
+
inputs=[
|
764 |
+
run_btn,
|
765 |
+
input_img,
|
766 |
+
background_remove_state,
|
767 |
+
foreground_ratio,
|
768 |
+
no_crop,
|
769 |
+
guidance_scale,
|
770 |
+
random_seed,
|
771 |
+
pc_upload,
|
772 |
+
pc_cond_file,
|
773 |
+
remesh_option,
|
774 |
+
vertex_count_type,
|
775 |
+
vertex_count_slider,
|
776 |
+
texture_size,
|
777 |
+
],
|
778 |
+
outputs=[
|
779 |
+
run_btn,
|
780 |
+
img_proc_state,
|
781 |
+
background_remove_state,
|
782 |
+
preview_removal,
|
783 |
+
output_3d,
|
784 |
+
hdr_row,
|
785 |
+
point_cloud_row,
|
786 |
+
point_cloud_editor,
|
787 |
+
pc_download,
|
788 |
+
regenerate_btn,
|
789 |
+
],
|
790 |
+
)
|
791 |
+
|
792 |
+
demo.queue().launch()
|
load/tets/160_tets.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
|
3 |
+
size 15408790
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops==0.7.0
|
2 |
+
jaxtyping==0.2.31
|
3 |
+
omegaconf==2.3.0
|
4 |
+
transformers==4.42.3
|
5 |
+
loralib==0.1.2
|
6 |
+
git+https://github.com/openai/CLIP.git
|
7 |
+
git+https://github.com/SunzeY/AlphaCLIP.git
|
8 |
+
trimesh==4.4.1
|
9 |
+
numpy==1.26.4
|
10 |
+
huggingface-hub==0.23.4
|
11 |
+
transparent-background==1.3.3
|
12 |
+
gradio==4.43.0
|
13 |
+
gradio-litmodel3d==0.0.1
|
14 |
+
gradio-pointcloudeditor==0.0.9
|
15 |
+
gpytoolbox==0.2.0
|
16 |
+
# ./texture_baker/
|
17 |
+
# ./uv_unwrapper/
|
ruff.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[lint]
|
2 |
+
ignore = ["F722", "F821"]
|
3 |
+
extend-select = ["I"]
|
run.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from contextlib import nullcontext
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm
|
8 |
+
from transparent_background import Remover
|
9 |
+
|
10 |
+
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
|
11 |
+
from spar3d.system import SPAR3D
|
12 |
+
from spar3d.utils import foreground_crop, get_device, remove_background
|
13 |
+
|
14 |
+
|
15 |
+
def check_positive(value):
|
16 |
+
ivalue = int(value)
|
17 |
+
if ivalue <= 0:
|
18 |
+
raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
|
19 |
+
return ivalue
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument(
|
25 |
+
"image", type=str, nargs="+", help="Path to input image(s) or folder."
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--device",
|
29 |
+
default=get_device(),
|
30 |
+
type=str,
|
31 |
+
help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--pretrained-model",
|
35 |
+
default="stabilityai/spar3d",
|
36 |
+
type=str,
|
37 |
+
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/spar3d'",
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--foreground-ratio",
|
41 |
+
default=1.3,
|
42 |
+
type=float,
|
43 |
+
help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--output-dir",
|
47 |
+
default="output/",
|
48 |
+
type=str,
|
49 |
+
help="Output directory to save the results. Default: 'output/'",
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--texture-resolution",
|
53 |
+
default=1024,
|
54 |
+
type=int,
|
55 |
+
help="Texture atlas resolution. Default: 1024",
|
56 |
+
)
|
57 |
+
|
58 |
+
remesh_choices = ["none"]
|
59 |
+
if TRIANGLE_REMESH_AVAILABLE:
|
60 |
+
remesh_choices.append("triangle")
|
61 |
+
if QUAD_REMESH_AVAILABLE:
|
62 |
+
remesh_choices.append("quad")
|
63 |
+
parser.add_argument(
|
64 |
+
"--remesh_option",
|
65 |
+
choices=remesh_choices,
|
66 |
+
default="none",
|
67 |
+
help="Remeshing option",
|
68 |
+
)
|
69 |
+
if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
|
70 |
+
parser.add_argument(
|
71 |
+
"--reduction_count_type",
|
72 |
+
choices=["keep", "vertex", "faces"],
|
73 |
+
default="keep",
|
74 |
+
help="Vertex count type",
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--target_count",
|
78 |
+
type=check_positive,
|
79 |
+
help="Selected target count.",
|
80 |
+
default=2000,
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--batch_size", default=1, type=int, help="Batch size for inference"
|
84 |
+
)
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
# Ensure args.device contains cuda
|
88 |
+
devices = ["cuda", "mps", "cpu"]
|
89 |
+
if not any(args.device in device for device in devices):
|
90 |
+
raise ValueError("Invalid device. Use cuda, mps or cpu")
|
91 |
+
|
92 |
+
output_dir = args.output_dir
|
93 |
+
os.makedirs(output_dir, exist_ok=True)
|
94 |
+
|
95 |
+
device = args.device
|
96 |
+
if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
|
97 |
+
device = "cpu"
|
98 |
+
|
99 |
+
print("Device used: ", device)
|
100 |
+
|
101 |
+
model = SPAR3D.from_pretrained(
|
102 |
+
args.pretrained_model,
|
103 |
+
config_name="config.yaml",
|
104 |
+
weight_name="model.safetensors",
|
105 |
+
)
|
106 |
+
model.to(device)
|
107 |
+
model.eval()
|
108 |
+
|
109 |
+
bg_remover = Remover(device=device)
|
110 |
+
images = []
|
111 |
+
idx = 0
|
112 |
+
for image_path in args.image:
|
113 |
+
|
114 |
+
def handle_image(image_path, idx):
|
115 |
+
image = remove_background(
|
116 |
+
Image.open(image_path).convert("RGBA"), bg_remover
|
117 |
+
)
|
118 |
+
image = foreground_crop(image, args.foreground_ratio)
|
119 |
+
os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
|
120 |
+
image.save(os.path.join(output_dir, str(idx), "input.png"))
|
121 |
+
images.append(image)
|
122 |
+
|
123 |
+
if os.path.isdir(image_path):
|
124 |
+
image_paths = [
|
125 |
+
os.path.join(image_path, f)
|
126 |
+
for f in os.listdir(image_path)
|
127 |
+
if f.endswith((".png", ".jpg", ".jpeg"))
|
128 |
+
]
|
129 |
+
for image_path in image_paths:
|
130 |
+
handle_image(image_path, idx)
|
131 |
+
idx += 1
|
132 |
+
else:
|
133 |
+
handle_image(image_path, idx)
|
134 |
+
idx += 1
|
135 |
+
|
136 |
+
vertex_count = (
|
137 |
+
-1
|
138 |
+
if args.reduction_count_type == "keep"
|
139 |
+
else (
|
140 |
+
args.target_count
|
141 |
+
if args.reduction_count_type == "vertex"
|
142 |
+
else args.target_count // 2
|
143 |
+
)
|
144 |
+
)
|
145 |
+
|
146 |
+
for i in tqdm(range(0, len(images), args.batch_size)):
|
147 |
+
image = images[i : i + args.batch_size]
|
148 |
+
if torch.cuda.is_available():
|
149 |
+
torch.cuda.reset_peak_memory_stats()
|
150 |
+
with torch.no_grad():
|
151 |
+
with (
|
152 |
+
torch.autocast(device_type=device, dtype=torch.float16)
|
153 |
+
if "cuda" in device
|
154 |
+
else nullcontext()
|
155 |
+
):
|
156 |
+
mesh, glob_dict = model.run_image(
|
157 |
+
image,
|
158 |
+
bake_resolution=args.texture_resolution,
|
159 |
+
remesh=args.remesh_option,
|
160 |
+
vertex_count=args.target_vertex_count,
|
161 |
+
return_points=True,
|
162 |
+
)
|
163 |
+
if torch.cuda.is_available():
|
164 |
+
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
165 |
+
elif torch.backends.mps.is_available():
|
166 |
+
print(
|
167 |
+
"Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
|
168 |
+
)
|
169 |
+
|
170 |
+
if len(image) == 1:
|
171 |
+
out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
|
172 |
+
mesh.export(out_mesh_path, include_normals=True)
|
173 |
+
out_points_path = os.path.join(output_dir, str(i), "points.ply")
|
174 |
+
glob_dict["point_clouds"][0].export(out_points_path)
|
175 |
+
else:
|
176 |
+
for j in range(len(mesh)):
|
177 |
+
out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
|
178 |
+
mesh[j].export(out_mesh_path, include_normals=True)
|
179 |
+
out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
|
180 |
+
glob_dict["point_clouds"][j].export(out_points_path)
|
spar3d/models/camera.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from spar3d.models.utils import BaseModule
|
8 |
+
|
9 |
+
|
10 |
+
class LinearCameraEmbedder(BaseModule):
|
11 |
+
@dataclass
|
12 |
+
class Config(BaseModule.Config):
|
13 |
+
in_channels: int = 25
|
14 |
+
out_channels: int = 768
|
15 |
+
conditions: List[str] = field(default_factory=list)
|
16 |
+
|
17 |
+
cfg: Config
|
18 |
+
|
19 |
+
def configure(self) -> None:
|
20 |
+
self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
|
21 |
+
|
22 |
+
def forward(self, **kwargs):
|
23 |
+
cond_tensors = []
|
24 |
+
for cond_name in self.cfg.conditions:
|
25 |
+
assert cond_name in kwargs
|
26 |
+
cond = kwargs[cond_name]
|
27 |
+
# cond in shape (B, Nv, ...)
|
28 |
+
cond_tensors.append(cond.view(*cond.shape[:2], -1))
|
29 |
+
cond_tensor = torch.cat(cond_tensors, dim=-1)
|
30 |
+
assert cond_tensor.shape[-1] == self.cfg.in_channels
|
31 |
+
embedding = self.linear(cond_tensor)
|
32 |
+
return embedding
|
spar3d/models/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from: https://github.com/openai/point-e
|
3 |
+
# Licensed under the MIT License
|
4 |
+
# Copyright (c) 2022 OpenAI
|
5 |
+
|
6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
8 |
+
# in the Software without restriction, including without limitation the rights
|
9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
11 |
+
# furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
# The above copyright notice and this permission notice shall be included in all
|
14 |
+
# copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
# SOFTWARE.
|
23 |
+
|
24 |
+
# --------------------------------------------------------
|
25 |
+
|
26 |
+
import math
|
27 |
+
from typing import Any, Dict, Iterable, Optional, Sequence, Union
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch as th
|
31 |
+
|
32 |
+
|
33 |
+
def sigmoid_schedule(t, start=-3, end=3, tau=0.6, clip_min=1e-9):
|
34 |
+
def sigmoid(x):
|
35 |
+
return 1 / (1 + np.exp(-x))
|
36 |
+
|
37 |
+
v_start = sigmoid(start / tau)
|
38 |
+
v_end = sigmoid(end / tau)
|
39 |
+
output = sigmoid((t * (end - start) + start) / tau)
|
40 |
+
output = (v_end - output) / (v_end - v_start)
|
41 |
+
return np.clip(output, clip_min, 1.0)
|
42 |
+
|
43 |
+
|
44 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
45 |
+
"""
|
46 |
+
This is the deprecated API for creating beta schedules.
|
47 |
+
|
48 |
+
See get_named_beta_schedule() for the new library of schedules.
|
49 |
+
"""
|
50 |
+
if beta_schedule == "linear":
|
51 |
+
betas = np.linspace(
|
52 |
+
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
raise NotImplementedError(beta_schedule)
|
56 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
57 |
+
return betas
|
58 |
+
|
59 |
+
|
60 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, exp_p=12):
|
61 |
+
"""
|
62 |
+
Get a pre-defined beta schedule for the given name.
|
63 |
+
|
64 |
+
The beta schedule library consists of beta schedules which remain similar
|
65 |
+
in the limit of num_diffusion_timesteps.
|
66 |
+
Beta schedules may be added, but should not be removed or changed once
|
67 |
+
they are committed to maintain backwards compatibility.
|
68 |
+
"""
|
69 |
+
if schedule_name == "linear":
|
70 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
71 |
+
# diffusion steps.
|
72 |
+
scale = 1000 / num_diffusion_timesteps
|
73 |
+
return get_beta_schedule(
|
74 |
+
"linear",
|
75 |
+
beta_start=scale * 0.0001,
|
76 |
+
beta_end=scale * 0.02,
|
77 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
78 |
+
)
|
79 |
+
elif schedule_name == "cosine":
|
80 |
+
return betas_for_alpha_bar(
|
81 |
+
num_diffusion_timesteps,
|
82 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
83 |
+
)
|
84 |
+
elif schedule_name == "sigmoid":
|
85 |
+
# Sigmoid schedule passed through betas_for_alpha_bar
|
86 |
+
return betas_for_alpha_bar(
|
87 |
+
num_diffusion_timesteps, lambda t: sigmoid_schedule(t)
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
91 |
+
|
92 |
+
|
93 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
94 |
+
"""
|
95 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
96 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
97 |
+
|
98 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
99 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
100 |
+
produces the cumulative product of (1-beta) up to that
|
101 |
+
part of the diffusion process.
|
102 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
103 |
+
prevent singularities.
|
104 |
+
"""
|
105 |
+
betas = []
|
106 |
+
for i in range(num_diffusion_timesteps):
|
107 |
+
t1 = i / num_diffusion_timesteps
|
108 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
109 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
110 |
+
return np.array(betas)
|
111 |
+
|
112 |
+
|
113 |
+
def space_timesteps(num_timesteps, section_counts):
|
114 |
+
"""
|
115 |
+
Create a list of timesteps to use from an original diffusion process,
|
116 |
+
given the number of timesteps we want to take from equally-sized portions
|
117 |
+
of the original process.
|
118 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
119 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
120 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
121 |
+
:param num_timesteps: the number of diffusion steps in the original
|
122 |
+
process to divide up.
|
123 |
+
:param section_counts: either a list of numbers, or a string containing
|
124 |
+
comma-separated numbers, indicating the step count
|
125 |
+
per section. As a special case, use "ddimN" where N
|
126 |
+
is a number of steps to use the striding from the
|
127 |
+
DDIM paper.
|
128 |
+
:return: a set of diffusion steps from the original process to use.
|
129 |
+
"""
|
130 |
+
if isinstance(section_counts, str):
|
131 |
+
if section_counts.startswith("ddim"):
|
132 |
+
desired_count = int(section_counts[len("ddim") :])
|
133 |
+
for i in range(1, num_timesteps):
|
134 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
135 |
+
return set(range(0, num_timesteps, i))
|
136 |
+
raise ValueError(
|
137 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
138 |
+
)
|
139 |
+
elif section_counts.startswith("exact"):
|
140 |
+
res = set(int(x) for x in section_counts[len("exact") :].split(","))
|
141 |
+
for x in res:
|
142 |
+
if x < 0 or x >= num_timesteps:
|
143 |
+
raise ValueError(f"timestep out of bounds: {x}")
|
144 |
+
return res
|
145 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
146 |
+
size_per = num_timesteps // len(section_counts)
|
147 |
+
extra = num_timesteps % len(section_counts)
|
148 |
+
start_idx = 0
|
149 |
+
all_steps = []
|
150 |
+
for i, section_count in enumerate(section_counts):
|
151 |
+
size = size_per + (1 if i < extra else 0)
|
152 |
+
if size < section_count:
|
153 |
+
raise ValueError(
|
154 |
+
f"cannot divide section of {size} steps into {section_count}"
|
155 |
+
)
|
156 |
+
if section_count <= 1:
|
157 |
+
frac_stride = 1
|
158 |
+
else:
|
159 |
+
frac_stride = (size - 1) / (section_count - 1)
|
160 |
+
cur_idx = 0.0
|
161 |
+
taken_steps = []
|
162 |
+
for _ in range(section_count):
|
163 |
+
taken_steps.append(start_idx + round(cur_idx))
|
164 |
+
cur_idx += frac_stride
|
165 |
+
all_steps += taken_steps
|
166 |
+
start_idx += size
|
167 |
+
return set(all_steps)
|
168 |
+
|
169 |
+
|
170 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
171 |
+
"""Extract values from a 1-D numpy array for a batch of indices."""
|
172 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
173 |
+
while len(res.shape) < len(broadcast_shape):
|
174 |
+
res = res[..., None]
|
175 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
176 |
+
|
177 |
+
|
178 |
+
class GaussianDiffusion:
|
179 |
+
"""
|
180 |
+
Utilities for sampling from Gaussian diffusion models.
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
*,
|
186 |
+
betas: Sequence[float],
|
187 |
+
model_mean_type: str,
|
188 |
+
model_var_type: str,
|
189 |
+
channel_scales: Optional[np.ndarray] = None,
|
190 |
+
channel_biases: Optional[np.ndarray] = None,
|
191 |
+
):
|
192 |
+
self.model_mean_type = model_mean_type
|
193 |
+
self.model_var_type = model_var_type
|
194 |
+
self.channel_scales = channel_scales
|
195 |
+
self.channel_biases = channel_biases
|
196 |
+
|
197 |
+
# Use float64 for accuracy
|
198 |
+
betas = np.array(betas, dtype=np.float64)
|
199 |
+
self.betas = betas
|
200 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
201 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
202 |
+
|
203 |
+
self.num_timesteps = int(betas.shape[0])
|
204 |
+
|
205 |
+
alphas = 1.0 - betas
|
206 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
207 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
208 |
+
|
209 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
210 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
211 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
212 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
213 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
214 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
215 |
+
self.posterior_variance = (
|
216 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
217 |
+
)
|
218 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
219 |
+
self.posterior_log_variance_clipped = np.log(
|
220 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
221 |
+
)
|
222 |
+
|
223 |
+
self.posterior_mean_coef1 = (
|
224 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
225 |
+
)
|
226 |
+
self.posterior_mean_coef2 = (
|
227 |
+
(1.0 - self.alphas_cumprod_prev)
|
228 |
+
* np.sqrt(alphas)
|
229 |
+
/ (1.0 - self.alphas_cumprod)
|
230 |
+
)
|
231 |
+
|
232 |
+
def scale_channels(self, x: th.Tensor) -> th.Tensor:
|
233 |
+
"""Apply channel-wise scaling."""
|
234 |
+
if self.channel_scales is not None:
|
235 |
+
x = x * th.from_numpy(self.channel_scales).to(x).reshape(
|
236 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
237 |
+
)
|
238 |
+
if self.channel_biases is not None:
|
239 |
+
x = x + th.from_numpy(self.channel_biases).to(x).reshape(
|
240 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
241 |
+
)
|
242 |
+
return x
|
243 |
+
|
244 |
+
def unscale_channels(self, x: th.Tensor) -> th.Tensor:
|
245 |
+
"""Remove channel-wise scaling."""
|
246 |
+
if self.channel_biases is not None:
|
247 |
+
x = x - th.from_numpy(self.channel_biases).to(x).reshape(
|
248 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
249 |
+
)
|
250 |
+
if self.channel_scales is not None:
|
251 |
+
x = x / th.from_numpy(self.channel_scales).to(x).reshape(
|
252 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
253 |
+
)
|
254 |
+
return x
|
255 |
+
|
256 |
+
def unscale_out_dict(
|
257 |
+
self, out: Dict[str, Union[th.Tensor, Any]]
|
258 |
+
) -> Dict[str, Union[th.Tensor, Any]]:
|
259 |
+
return {
|
260 |
+
k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v)
|
261 |
+
for k, v in out.items()
|
262 |
+
}
|
263 |
+
|
264 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
265 |
+
"""
|
266 |
+
Compute the mean and variance of the diffusion posterior:
|
267 |
+
|
268 |
+
q(x_{t-1} | x_t, x_0)
|
269 |
+
|
270 |
+
"""
|
271 |
+
assert x_start.shape == x_t.shape
|
272 |
+
posterior_mean = (
|
273 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
274 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
275 |
+
)
|
276 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
277 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
278 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
279 |
+
)
|
280 |
+
assert (
|
281 |
+
posterior_mean.shape[0]
|
282 |
+
== posterior_variance.shape[0]
|
283 |
+
== posterior_log_variance_clipped.shape[0]
|
284 |
+
== x_start.shape[0]
|
285 |
+
)
|
286 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
287 |
+
|
288 |
+
def p_mean_variance(
|
289 |
+
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
|
290 |
+
):
|
291 |
+
"""
|
292 |
+
Apply the model to get p(x_{t-1} | x_t).
|
293 |
+
"""
|
294 |
+
if model_kwargs is None:
|
295 |
+
model_kwargs = {}
|
296 |
+
|
297 |
+
B, C = x.shape[:2]
|
298 |
+
assert t.shape == (B,)
|
299 |
+
|
300 |
+
# Direct prediction of eps
|
301 |
+
model_output = model(x, t, **model_kwargs)
|
302 |
+
if isinstance(model_output, tuple):
|
303 |
+
model_output, prev_latent = model_output
|
304 |
+
model_kwargs["prev_latent"] = prev_latent
|
305 |
+
|
306 |
+
# Convert model output to mean and variance
|
307 |
+
model_variance, model_log_variance = {
|
308 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
309 |
+
# to get a better decoder log likelihood.
|
310 |
+
"fixed_large": (
|
311 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
312 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
313 |
+
),
|
314 |
+
"fixed_small": (
|
315 |
+
self.posterior_variance,
|
316 |
+
self.posterior_log_variance_clipped,
|
317 |
+
),
|
318 |
+
}[self.model_var_type]
|
319 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
320 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
321 |
+
|
322 |
+
def process_xstart(x):
|
323 |
+
if denoised_fn is not None:
|
324 |
+
x = denoised_fn(x)
|
325 |
+
if clip_denoised:
|
326 |
+
x = x.clamp(
|
327 |
+
-self.channel_scales[0] * 0.67, self.channel_scales[0] * 0.67
|
328 |
+
)
|
329 |
+
x[:, 3:] = x[:, 3:].clamp(
|
330 |
+
-self.channel_scales[3] * 0.5, self.channel_scales[3] * 0.5
|
331 |
+
)
|
332 |
+
return x
|
333 |
+
return x
|
334 |
+
|
335 |
+
if self.model_mean_type == "x_prev":
|
336 |
+
pred_xstart = process_xstart(
|
337 |
+
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
338 |
+
)
|
339 |
+
model_mean = model_output
|
340 |
+
elif self.model_mean_type in ["x_start", "epsilon"]:
|
341 |
+
if self.model_mean_type == "x_start":
|
342 |
+
pred_xstart = process_xstart(model_output)
|
343 |
+
else:
|
344 |
+
pred_xstart = process_xstart(
|
345 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
346 |
+
)
|
347 |
+
model_mean, _, _ = self.q_posterior_mean_variance(
|
348 |
+
x_start=pred_xstart, x_t=x, t=t
|
349 |
+
)
|
350 |
+
# print('p_mean_variance:', pred_xstart.min(), pred_xstart.max())
|
351 |
+
else:
|
352 |
+
raise NotImplementedError(self.model_mean_type)
|
353 |
+
|
354 |
+
assert (
|
355 |
+
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
356 |
+
)
|
357 |
+
return {
|
358 |
+
"mean": model_mean,
|
359 |
+
"variance": model_variance,
|
360 |
+
"log_variance": model_log_variance,
|
361 |
+
"pred_xstart": pred_xstart,
|
362 |
+
}
|
363 |
+
|
364 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
365 |
+
assert x_t.shape == eps.shape
|
366 |
+
return (
|
367 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
368 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
369 |
+
)
|
370 |
+
|
371 |
+
def _predict_xstart_from_xprev(self, x_t, t, xprev):
|
372 |
+
assert x_t.shape == xprev.shape
|
373 |
+
return ( # (xprev - coef2*x_t) / coef1
|
374 |
+
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
|
375 |
+
- _extract_into_tensor(
|
376 |
+
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
|
377 |
+
)
|
378 |
+
* x_t
|
379 |
+
)
|
380 |
+
|
381 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
382 |
+
return (
|
383 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
384 |
+
- pred_xstart
|
385 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
386 |
+
|
387 |
+
def ddim_sample_loop_progressive(
|
388 |
+
self,
|
389 |
+
model,
|
390 |
+
shape,
|
391 |
+
noise=None,
|
392 |
+
clip_denoised=True,
|
393 |
+
denoised_fn=None,
|
394 |
+
model_kwargs=None,
|
395 |
+
device=None,
|
396 |
+
progress=False,
|
397 |
+
eta=0.0,
|
398 |
+
):
|
399 |
+
"""
|
400 |
+
Use DDIM to sample from the model and yield intermediate samples.
|
401 |
+
"""
|
402 |
+
if device is None:
|
403 |
+
device = next(model.parameters()).device
|
404 |
+
assert isinstance(shape, (tuple, list))
|
405 |
+
if noise is not None:
|
406 |
+
img = noise
|
407 |
+
else:
|
408 |
+
img = th.randn(*shape, device=device)
|
409 |
+
|
410 |
+
indices = list(range(self.num_timesteps))[::-1]
|
411 |
+
|
412 |
+
if progress:
|
413 |
+
from tqdm.auto import tqdm
|
414 |
+
|
415 |
+
indices = tqdm(indices)
|
416 |
+
|
417 |
+
for i in indices:
|
418 |
+
t = th.tensor([i] * shape[0], device=device)
|
419 |
+
with th.no_grad():
|
420 |
+
out = self.ddim_sample(
|
421 |
+
model,
|
422 |
+
img,
|
423 |
+
t,
|
424 |
+
clip_denoised=clip_denoised,
|
425 |
+
denoised_fn=denoised_fn,
|
426 |
+
model_kwargs=model_kwargs,
|
427 |
+
eta=eta,
|
428 |
+
)
|
429 |
+
yield self.unscale_out_dict(out)
|
430 |
+
img = out["sample"]
|
431 |
+
|
432 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
433 |
+
return (
|
434 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
435 |
+
- pred_xstart
|
436 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
437 |
+
|
438 |
+
def ddim_sample(
|
439 |
+
self,
|
440 |
+
model,
|
441 |
+
x,
|
442 |
+
t,
|
443 |
+
clip_denoised=True,
|
444 |
+
denoised_fn=None,
|
445 |
+
model_kwargs=None,
|
446 |
+
eta=0.0,
|
447 |
+
):
|
448 |
+
"""
|
449 |
+
Sample x_{t-1} from the model using DDIM.
|
450 |
+
"""
|
451 |
+
out = self.p_mean_variance(
|
452 |
+
model,
|
453 |
+
x,
|
454 |
+
t,
|
455 |
+
clip_denoised=clip_denoised,
|
456 |
+
denoised_fn=denoised_fn,
|
457 |
+
model_kwargs=model_kwargs,
|
458 |
+
)
|
459 |
+
|
460 |
+
# Usually our model outputs epsilon, but we re-derive it
|
461 |
+
# in case we used x_start or x_prev prediction.
|
462 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
463 |
+
|
464 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
465 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
466 |
+
sigma = (
|
467 |
+
eta
|
468 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
469 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
470 |
+
)
|
471 |
+
|
472 |
+
# Equation 12.
|
473 |
+
noise = th.randn_like(x)
|
474 |
+
mean_pred = (
|
475 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
476 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
477 |
+
)
|
478 |
+
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
479 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
480 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
481 |
+
|
482 |
+
|
483 |
+
class SpacedDiffusion(GaussianDiffusion):
|
484 |
+
"""
|
485 |
+
A diffusion process which can skip steps in a base diffusion process.
|
486 |
+
"""
|
487 |
+
|
488 |
+
def __init__(self, use_timesteps: Iterable[int], **kwargs):
|
489 |
+
self.use_timesteps = set(use_timesteps)
|
490 |
+
self.timestep_map = []
|
491 |
+
self.original_num_steps = len(kwargs["betas"])
|
492 |
+
|
493 |
+
base_diffusion = GaussianDiffusion(**kwargs)
|
494 |
+
last_alpha_cumprod = 1.0
|
495 |
+
new_betas = []
|
496 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
497 |
+
if i in self.use_timesteps:
|
498 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
499 |
+
last_alpha_cumprod = alpha_cumprod
|
500 |
+
self.timestep_map.append(i)
|
501 |
+
kwargs["betas"] = np.array(new_betas)
|
502 |
+
super().__init__(**kwargs)
|
503 |
+
|
504 |
+
def p_mean_variance(self, model, *args, **kwargs):
|
505 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
506 |
+
|
507 |
+
def _wrap_model(self, model):
|
508 |
+
if isinstance(model, _WrappedModel):
|
509 |
+
return model
|
510 |
+
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
511 |
+
|
512 |
+
|
513 |
+
class _WrappedModel:
|
514 |
+
"""Helper class to wrap models for SpacedDiffusion."""
|
515 |
+
|
516 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
517 |
+
self.model = model
|
518 |
+
self.timestep_map = timestep_map
|
519 |
+
self.original_num_steps = original_num_steps
|
520 |
+
|
521 |
+
def __call__(self, x, ts, **kwargs):
|
522 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
523 |
+
new_ts = map_tensor[ts]
|
524 |
+
return self.model(x, new_ts, **kwargs)
|
spar3d/models/diffusion/sampler.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from: https://github.com/openai/point-e
|
3 |
+
# Licensed under the MIT License
|
4 |
+
# Copyright (c) 2022 OpenAI
|
5 |
+
|
6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
8 |
+
# in the Software without restriction, including without limitation the rights
|
9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
11 |
+
# furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
# The above copyright notice and this permission notice shall be included in all
|
14 |
+
# copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
# SOFTWARE.
|
23 |
+
|
24 |
+
# --------------------------------------------------------
|
25 |
+
|
26 |
+
from typing import Dict, Iterator
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
|
31 |
+
from .gaussian_diffusion import GaussianDiffusion
|
32 |
+
|
33 |
+
|
34 |
+
class PointCloudSampler:
|
35 |
+
"""
|
36 |
+
A wrapper around a model that produces conditional sample tensors.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
model: nn.Module,
|
42 |
+
diffusion: GaussianDiffusion,
|
43 |
+
num_points: int,
|
44 |
+
point_dim: int = 3,
|
45 |
+
guidance_scale: float = 3.0,
|
46 |
+
clip_denoised: bool = True,
|
47 |
+
sigma_min: float = 1e-3,
|
48 |
+
sigma_max: float = 120,
|
49 |
+
s_churn: float = 3,
|
50 |
+
):
|
51 |
+
self.model = model
|
52 |
+
self.num_points = num_points
|
53 |
+
self.point_dim = point_dim
|
54 |
+
self.guidance_scale = guidance_scale
|
55 |
+
self.clip_denoised = clip_denoised
|
56 |
+
self.sigma_min = sigma_min
|
57 |
+
self.sigma_max = sigma_max
|
58 |
+
self.s_churn = s_churn
|
59 |
+
|
60 |
+
self.diffusion = diffusion
|
61 |
+
|
62 |
+
def sample_batch_progressive(
|
63 |
+
self,
|
64 |
+
batch_size: int,
|
65 |
+
condition: torch.Tensor,
|
66 |
+
noise=None,
|
67 |
+
device=None,
|
68 |
+
guidance_scale=None,
|
69 |
+
) -> Iterator[Dict[str, torch.Tensor]]:
|
70 |
+
"""
|
71 |
+
Generate samples progressively using classifier-free guidance.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
batch_size: Number of samples to generate
|
75 |
+
condition: Conditioning tensor
|
76 |
+
noise: Optional initial noise tensor
|
77 |
+
device: Device to run on
|
78 |
+
guidance_scale: Optional override for guidance scale
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Iterator of dicts containing intermediate samples
|
82 |
+
"""
|
83 |
+
if guidance_scale is None:
|
84 |
+
guidance_scale = self.guidance_scale
|
85 |
+
|
86 |
+
sample_shape = (batch_size, self.point_dim, self.num_points)
|
87 |
+
|
88 |
+
# Double the batch for classifier-free guidance
|
89 |
+
if guidance_scale != 1 and guidance_scale != 0:
|
90 |
+
condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
|
91 |
+
if noise is not None:
|
92 |
+
noise = torch.cat([noise, noise], dim=0)
|
93 |
+
model_kwargs = {"condition": condition}
|
94 |
+
|
95 |
+
internal_batch_size = batch_size
|
96 |
+
if guidance_scale != 1 and guidance_scale != 0:
|
97 |
+
model = self._uncond_guide_model(self.model, guidance_scale)
|
98 |
+
internal_batch_size *= 2
|
99 |
+
else:
|
100 |
+
model = self.model
|
101 |
+
|
102 |
+
samples_it = self.diffusion.ddim_sample_loop_progressive(
|
103 |
+
model,
|
104 |
+
shape=(internal_batch_size, *sample_shape[1:]),
|
105 |
+
model_kwargs=model_kwargs,
|
106 |
+
device=device,
|
107 |
+
clip_denoised=self.clip_denoised,
|
108 |
+
noise=noise,
|
109 |
+
)
|
110 |
+
|
111 |
+
for x in samples_it:
|
112 |
+
samples = {
|
113 |
+
"xstart": x["pred_xstart"][:batch_size],
|
114 |
+
"xprev": x["sample"][:batch_size] if "sample" in x else x["x"],
|
115 |
+
}
|
116 |
+
yield samples
|
117 |
+
|
118 |
+
def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
|
119 |
+
"""
|
120 |
+
Wraps the model for classifier-free guidance.
|
121 |
+
"""
|
122 |
+
|
123 |
+
def model_fn(x_t, ts, **kwargs):
|
124 |
+
half = x_t[: len(x_t) // 2]
|
125 |
+
combined = torch.cat([half, half], dim=0)
|
126 |
+
model_out = model(combined, ts, **kwargs)
|
127 |
+
|
128 |
+
eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
|
129 |
+
cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
|
130 |
+
half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
|
131 |
+
eps = torch.cat([half_eps, half_eps], dim=0)
|
132 |
+
return torch.cat([eps, rest], dim=1)
|
133 |
+
|
134 |
+
return model_fn
|
spar3d/models/global_estimator/reni_estimator.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from spar3d.models.illumination.reni.env_map import RENIEnvMap
|
11 |
+
from spar3d.models.utils import BaseModule
|
12 |
+
|
13 |
+
|
14 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
15 |
+
assert d6.shape[-1] == 6, "Input tensor must have shape (..., 6)"
|
16 |
+
|
17 |
+
def proj_u2a(u, a):
|
18 |
+
r"""
|
19 |
+
u: batch x 3
|
20 |
+
a: batch x 3
|
21 |
+
"""
|
22 |
+
inner_prod = torch.sum(u * a, dim=-1, keepdim=True)
|
23 |
+
norm2 = torch.sum(u**2, dim=-1, keepdim=True)
|
24 |
+
norm2 = torch.clamp(norm2, min=1e-8)
|
25 |
+
factor = inner_prod / (norm2 + 1e-10)
|
26 |
+
return factor * u
|
27 |
+
|
28 |
+
x_raw, y_raw = d6[..., :3], d6[..., 3:]
|
29 |
+
|
30 |
+
x = F.normalize(x_raw, dim=-1)
|
31 |
+
y = F.normalize(y_raw - proj_u2a(x, y_raw), dim=-1)
|
32 |
+
z = torch.cross(x, y, dim=-1)
|
33 |
+
|
34 |
+
return torch.stack((x, y, z), dim=-1)
|
35 |
+
|
36 |
+
|
37 |
+
class ReniLatentCodeEstimator(BaseModule):
|
38 |
+
@dataclass
|
39 |
+
class Config(BaseModule.Config):
|
40 |
+
triplane_features: int = 40
|
41 |
+
|
42 |
+
n_layers: int = 5
|
43 |
+
hidden_features: int = 512
|
44 |
+
activation: str = "relu"
|
45 |
+
|
46 |
+
pool: str = "mean"
|
47 |
+
|
48 |
+
reni_env_config: dict = field(default_factory=dict)
|
49 |
+
|
50 |
+
cfg: Config
|
51 |
+
|
52 |
+
def configure(self):
|
53 |
+
layers = []
|
54 |
+
cur_features = self.cfg.triplane_features * 3
|
55 |
+
for _ in range(self.cfg.n_layers):
|
56 |
+
layers.append(
|
57 |
+
nn.Conv2d(
|
58 |
+
cur_features,
|
59 |
+
self.cfg.hidden_features,
|
60 |
+
kernel_size=3,
|
61 |
+
padding=0,
|
62 |
+
stride=2,
|
63 |
+
)
|
64 |
+
)
|
65 |
+
layers.append(self.make_activation(self.cfg.activation))
|
66 |
+
|
67 |
+
cur_features = self.cfg.hidden_features
|
68 |
+
|
69 |
+
self.layers = nn.Sequential(*layers)
|
70 |
+
|
71 |
+
self.reni_env_map = RENIEnvMap(self.cfg.reni_env_config)
|
72 |
+
self.latent_dim = self.reni_env_map.field.latent_dim
|
73 |
+
|
74 |
+
self.fc_latents = nn.Linear(self.cfg.hidden_features, self.latent_dim * 3)
|
75 |
+
nn.init.normal_(self.fc_latents.weight, mean=0.0, std=0.3)
|
76 |
+
|
77 |
+
self.fc_rotations = nn.Linear(self.cfg.hidden_features, 6)
|
78 |
+
nn.init.constant_(self.fc_rotations.bias, 0.0)
|
79 |
+
nn.init.normal_(
|
80 |
+
self.fc_rotations.weight, mean=0.0, std=0.01
|
81 |
+
) # Small variance here
|
82 |
+
|
83 |
+
self.fc_scale = nn.Linear(self.cfg.hidden_features, 1)
|
84 |
+
nn.init.constant_(self.fc_scale.bias, 0.0)
|
85 |
+
nn.init.normal_(self.fc_scale.weight, mean=0.0, std=0.01) # Small variance here
|
86 |
+
|
87 |
+
def make_activation(self, activation):
|
88 |
+
if activation == "relu":
|
89 |
+
return nn.ReLU(inplace=True)
|
90 |
+
elif activation == "silu":
|
91 |
+
return nn.SiLU(inplace=True)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
def forward(
|
96 |
+
self,
|
97 |
+
triplane: Float[Tensor, "B 3 F Ht Wt"],
|
98 |
+
) -> dict[str, Any]:
|
99 |
+
x = self.layers(
|
100 |
+
triplane.reshape(
|
101 |
+
triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
|
102 |
+
)
|
103 |
+
)
|
104 |
+
x = x.mean(dim=[-2, -1])
|
105 |
+
|
106 |
+
latents = self.fc_latents(x).reshape(-1, self.latent_dim, 3)
|
107 |
+
rotations = self.fc_rotations(x)
|
108 |
+
scale = self.fc_scale(x)
|
109 |
+
|
110 |
+
env_map = self.reni_env_map(latents, rotation_6d_to_matrix(rotations), scale)
|
111 |
+
|
112 |
+
return {"illumination": env_map["rgb"]}
|
spar3d/models/illumination/reni/components/film_siren.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/."""
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
def kaiming_leaky_init(m):
|
11 |
+
classname = m.__class__.__name__
|
12 |
+
if classname.find("Linear") != -1:
|
13 |
+
torch.nn.init.kaiming_normal_(
|
14 |
+
m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def frequency_init(freq):
|
19 |
+
def init(m):
|
20 |
+
with torch.no_grad():
|
21 |
+
if isinstance(m, nn.Linear):
|
22 |
+
num_input = m.weight.size(-1)
|
23 |
+
m.weight.uniform_(
|
24 |
+
-np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq
|
25 |
+
)
|
26 |
+
|
27 |
+
return init
|
28 |
+
|
29 |
+
|
30 |
+
def first_layer_film_sine_init(m):
|
31 |
+
with torch.no_grad():
|
32 |
+
if isinstance(m, nn.Linear):
|
33 |
+
num_input = m.weight.size(-1)
|
34 |
+
m.weight.uniform_(-1 / num_input, 1 / num_input)
|
35 |
+
|
36 |
+
|
37 |
+
class CustomMappingNetwork(nn.Module):
|
38 |
+
def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.network = []
|
42 |
+
|
43 |
+
for _ in range(map_hidden_layers):
|
44 |
+
self.network.append(nn.Linear(in_features, map_hidden_dim))
|
45 |
+
self.network.append(nn.LeakyReLU(0.2, inplace=True))
|
46 |
+
in_features = map_hidden_dim
|
47 |
+
|
48 |
+
self.network.append(nn.Linear(map_hidden_dim, map_output_dim))
|
49 |
+
|
50 |
+
self.network = nn.Sequential(*self.network)
|
51 |
+
|
52 |
+
self.network.apply(kaiming_leaky_init)
|
53 |
+
with torch.no_grad():
|
54 |
+
self.network[-1].weight *= 0.25
|
55 |
+
|
56 |
+
def forward(self, z):
|
57 |
+
frequencies_offsets = self.network(z)
|
58 |
+
frequencies = frequencies_offsets[
|
59 |
+
..., : torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
|
60 |
+
]
|
61 |
+
phase_shifts = frequencies_offsets[
|
62 |
+
..., torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor") :
|
63 |
+
]
|
64 |
+
|
65 |
+
return frequencies, phase_shifts
|
66 |
+
|
67 |
+
|
68 |
+
class FiLMLayer(nn.Module):
|
69 |
+
def __init__(self, input_dim, hidden_dim):
|
70 |
+
super().__init__()
|
71 |
+
self.layer = nn.Linear(input_dim, hidden_dim)
|
72 |
+
|
73 |
+
def forward(self, x, freq, phase_shift):
|
74 |
+
x = self.layer(x)
|
75 |
+
freq = freq.expand_as(x)
|
76 |
+
phase_shift = phase_shift.expand_as(x)
|
77 |
+
return torch.sin(freq * x + phase_shift)
|
78 |
+
|
79 |
+
|
80 |
+
class FiLMSiren(nn.Module):
|
81 |
+
"""FiLM Conditioned Siren network."""
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
in_dim: int,
|
86 |
+
hidden_layers: int,
|
87 |
+
hidden_features: int,
|
88 |
+
mapping_network_in_dim: int,
|
89 |
+
mapping_network_layers: int,
|
90 |
+
mapping_network_features: int,
|
91 |
+
out_dim: int,
|
92 |
+
outermost_linear: bool = False,
|
93 |
+
out_activation: Optional[nn.Module] = None,
|
94 |
+
) -> None:
|
95 |
+
super().__init__()
|
96 |
+
self.in_dim = in_dim
|
97 |
+
assert self.in_dim > 0
|
98 |
+
self.out_dim = out_dim if out_dim is not None else hidden_features
|
99 |
+
self.hidden_layers = hidden_layers
|
100 |
+
self.hidden_features = hidden_features
|
101 |
+
self.mapping_network_in_dim = mapping_network_in_dim
|
102 |
+
self.mapping_network_layers = mapping_network_layers
|
103 |
+
self.mapping_network_features = mapping_network_features
|
104 |
+
self.outermost_linear = outermost_linear
|
105 |
+
self.out_activation = out_activation
|
106 |
+
|
107 |
+
self.net = nn.ModuleList()
|
108 |
+
|
109 |
+
self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
|
110 |
+
|
111 |
+
for _ in range(self.hidden_layers - 1):
|
112 |
+
self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))
|
113 |
+
|
114 |
+
self.final_layer = None
|
115 |
+
if self.outermost_linear:
|
116 |
+
self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
|
117 |
+
self.final_layer.apply(frequency_init(25))
|
118 |
+
else:
|
119 |
+
final_layer = FiLMLayer(self.hidden_features, self.out_dim)
|
120 |
+
self.net.append(final_layer)
|
121 |
+
|
122 |
+
self.mapping_network = CustomMappingNetwork(
|
123 |
+
in_features=self.mapping_network_in_dim,
|
124 |
+
map_hidden_layers=self.mapping_network_layers,
|
125 |
+
map_hidden_dim=self.mapping_network_features,
|
126 |
+
map_output_dim=(len(self.net)) * self.hidden_features * 2,
|
127 |
+
)
|
128 |
+
|
129 |
+
self.net.apply(frequency_init(25))
|
130 |
+
self.net[0].apply(first_layer_film_sine_init)
|
131 |
+
|
132 |
+
def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
|
133 |
+
"""Get conditiional frequencies and phase shifts from mapping network."""
|
134 |
+
frequencies = frequencies * 15 + 30
|
135 |
+
|
136 |
+
for index, layer in enumerate(self.net):
|
137 |
+
start = index * self.hidden_features
|
138 |
+
end = (index + 1) * self.hidden_features
|
139 |
+
x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
|
140 |
+
|
141 |
+
x = self.final_layer(x) if self.final_layer is not None else x
|
142 |
+
output = self.out_activation(x) if self.out_activation is not None else x
|
143 |
+
return output
|
144 |
+
|
145 |
+
def forward(self, x, conditioning_input):
|
146 |
+
"""Forward pass."""
|
147 |
+
frequencies, phase_shifts = self.mapping_network(conditioning_input)
|
148 |
+
return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)
|
spar3d/models/illumination/reni/components/siren.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Siren MLP https://www.vincentsitzmann.com/siren/"""
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class SineLayer(nn.Module):
|
11 |
+
"""
|
12 |
+
Sine layer for the SIREN network.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.omega_0 = omega_0
|
20 |
+
self.is_first = is_first
|
21 |
+
|
22 |
+
self.in_features = in_features
|
23 |
+
self.linear = nn.Linear(in_features, out_features, bias=bias)
|
24 |
+
|
25 |
+
self.init_weights()
|
26 |
+
|
27 |
+
def init_weights(self):
|
28 |
+
with torch.no_grad():
|
29 |
+
if self.is_first:
|
30 |
+
self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
|
31 |
+
else:
|
32 |
+
self.linear.weight.uniform_(
|
33 |
+
-np.sqrt(6 / self.in_features) / self.omega_0,
|
34 |
+
np.sqrt(6 / self.in_features) / self.omega_0,
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return torch.sin(self.omega_0 * self.linear(x))
|
39 |
+
|
40 |
+
|
41 |
+
class Siren(nn.Module):
|
42 |
+
"""Siren network.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
in_dim: Input layer dimension
|
46 |
+
num_layers: Number of network layers
|
47 |
+
layer_width: Width of each MLP layer
|
48 |
+
out_dim: Output layer dimension. Uses layer_width if None.
|
49 |
+
activation: intermediate layer activation function.
|
50 |
+
out_activation: output activation function.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
in_dim: int,
|
56 |
+
hidden_layers: int,
|
57 |
+
hidden_features: int,
|
58 |
+
out_dim: Optional[int] = None,
|
59 |
+
outermost_linear: bool = False,
|
60 |
+
first_omega_0: float = 30,
|
61 |
+
hidden_omega_0: float = 30,
|
62 |
+
out_activation: Optional[nn.Module] = None,
|
63 |
+
) -> None:
|
64 |
+
super().__init__()
|
65 |
+
self.in_dim = in_dim
|
66 |
+
assert self.in_dim > 0
|
67 |
+
self.out_dim = out_dim if out_dim is not None else hidden_features
|
68 |
+
self.outermost_linear = outermost_linear
|
69 |
+
self.first_omega_0 = first_omega_0
|
70 |
+
self.hidden_omega_0 = hidden_omega_0
|
71 |
+
self.hidden_layers = hidden_layers
|
72 |
+
self.layer_width = hidden_features
|
73 |
+
self.out_activation = out_activation
|
74 |
+
|
75 |
+
self.net = []
|
76 |
+
self.net.append(
|
77 |
+
SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
|
78 |
+
)
|
79 |
+
|
80 |
+
for _ in range(hidden_layers):
|
81 |
+
self.net.append(
|
82 |
+
SineLayer(
|
83 |
+
hidden_features,
|
84 |
+
hidden_features,
|
85 |
+
is_first=False,
|
86 |
+
omega_0=hidden_omega_0,
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
if outermost_linear:
|
91 |
+
final_layer = nn.Linear(hidden_features, self.out_dim)
|
92 |
+
|
93 |
+
with torch.no_grad():
|
94 |
+
final_layer.weight.uniform_(
|
95 |
+
-np.sqrt(6 / hidden_features) / hidden_omega_0,
|
96 |
+
np.sqrt(6 / hidden_features) / hidden_omega_0,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.net.append(final_layer)
|
100 |
+
else:
|
101 |
+
self.net.append(
|
102 |
+
SineLayer(
|
103 |
+
hidden_features,
|
104 |
+
self.out_dim,
|
105 |
+
is_first=False,
|
106 |
+
omega_0=hidden_omega_0,
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
if self.out_activation is not None:
|
111 |
+
self.net.append(self.out_activation)
|
112 |
+
|
113 |
+
self.net = nn.Sequential(*self.net)
|
114 |
+
|
115 |
+
def forward(self, model_input):
|
116 |
+
"""Forward pass through the network"""
|
117 |
+
output = self.net(model_input)
|
118 |
+
return output
|
spar3d/models/illumination/reni/components/transformer_decoder.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
class MultiHeadAttention(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
direction_input_dim: int,
|
11 |
+
conditioning_input_dim: int,
|
12 |
+
latent_dim: int,
|
13 |
+
num_heads: int,
|
14 |
+
):
|
15 |
+
"""
|
16 |
+
Multi-Head Attention module.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
direction_input_dim (int): The input dimension of the directional input.
|
20 |
+
conditioning_input_dim (int): The input dimension of the conditioning input.
|
21 |
+
latent_dim (int): The latent dimension of the module.
|
22 |
+
num_heads (int): The number of heads to use in the attention mechanism.
|
23 |
+
"""
|
24 |
+
super().__init__()
|
25 |
+
assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"
|
26 |
+
self.num_heads = num_heads
|
27 |
+
self.head_dim = latent_dim // num_heads
|
28 |
+
self.scale = self.head_dim**-0.5
|
29 |
+
|
30 |
+
self.query = nn.Linear(direction_input_dim, latent_dim)
|
31 |
+
self.key = nn.Linear(conditioning_input_dim, latent_dim)
|
32 |
+
self.value = nn.Linear(conditioning_input_dim, latent_dim)
|
33 |
+
self.fc_out = nn.Linear(latent_dim, latent_dim)
|
34 |
+
|
35 |
+
def forward(
|
36 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
37 |
+
) -> torch.Tensor:
|
38 |
+
"""
|
39 |
+
Forward pass of the Multi-Head Attention module.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
query (torch.Tensor): The directional input tensor.
|
43 |
+
key (torch.Tensor): The conditioning input tensor for the keys.
|
44 |
+
value (torch.Tensor): The conditioning input tensor for the values.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
torch.Tensor: The output tensor of the Multi-Head Attention module.
|
48 |
+
"""
|
49 |
+
batch_size = query.size(0)
|
50 |
+
|
51 |
+
Q = (
|
52 |
+
self.query(query)
|
53 |
+
.view(batch_size, -1, self.num_heads, self.head_dim)
|
54 |
+
.transpose(1, 2)
|
55 |
+
)
|
56 |
+
K = (
|
57 |
+
self.key(key)
|
58 |
+
.view(batch_size, -1, self.num_heads, self.head_dim)
|
59 |
+
.transpose(1, 2)
|
60 |
+
)
|
61 |
+
V = (
|
62 |
+
self.value(value)
|
63 |
+
.view(batch_size, -1, self.num_heads, self.head_dim)
|
64 |
+
.transpose(1, 2)
|
65 |
+
)
|
66 |
+
|
67 |
+
attention = (
|
68 |
+
torch.einsum("bnqk,bnkh->bnqh", [Q, K.transpose(-2, -1)]) * self.scale
|
69 |
+
)
|
70 |
+
attention = torch.softmax(attention, dim=-1)
|
71 |
+
|
72 |
+
out = torch.einsum("bnqh,bnhv->bnqv", [attention, V])
|
73 |
+
out = (
|
74 |
+
out.transpose(1, 2)
|
75 |
+
.contiguous()
|
76 |
+
.view(batch_size, -1, self.num_heads * self.head_dim)
|
77 |
+
)
|
78 |
+
|
79 |
+
out = self.fc_out(out).squeeze(1)
|
80 |
+
return out
|
81 |
+
|
82 |
+
|
83 |
+
class AttentionLayer(nn.Module):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
direction_input_dim: int,
|
87 |
+
conditioning_input_dim: int,
|
88 |
+
latent_dim: int,
|
89 |
+
num_heads: int,
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
Attention Layer module.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
direction_input_dim (int): The input dimension of the directional input.
|
96 |
+
conditioning_input_dim (int): The input dimension of the conditioning input.
|
97 |
+
latent_dim (int): The latent dimension of the module.
|
98 |
+
num_heads (int): The number of heads to use in the attention mechanism.
|
99 |
+
"""
|
100 |
+
super().__init__()
|
101 |
+
self.mha = MultiHeadAttention(
|
102 |
+
direction_input_dim, conditioning_input_dim, latent_dim, num_heads
|
103 |
+
)
|
104 |
+
self.norm1 = nn.LayerNorm(latent_dim)
|
105 |
+
self.norm2 = nn.LayerNorm(latent_dim)
|
106 |
+
self.fc = nn.Sequential(
|
107 |
+
nn.Linear(latent_dim, latent_dim),
|
108 |
+
nn.ReLU(),
|
109 |
+
nn.Linear(latent_dim, latent_dim),
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self, directional_input: torch.Tensor, conditioning_input: torch.Tensor
|
114 |
+
) -> torch.Tensor:
|
115 |
+
"""
|
116 |
+
Forward pass of the Attention Layer module.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
directional_input (torch.Tensor): The directional input tensor.
|
120 |
+
conditioning_input (torch.Tensor): The conditioning input tensor.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
torch.Tensor: The output tensor of the Attention Layer module.
|
124 |
+
"""
|
125 |
+
attn_output = self.mha(
|
126 |
+
directional_input, conditioning_input, conditioning_input
|
127 |
+
)
|
128 |
+
out1 = self.norm1(attn_output + directional_input)
|
129 |
+
fc_output = self.fc(out1)
|
130 |
+
out2 = self.norm2(fc_output + out1)
|
131 |
+
return out2
|
132 |
+
|
133 |
+
|
134 |
+
class Decoder(nn.Module):
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
in_dim: int,
|
138 |
+
conditioning_input_dim: int,
|
139 |
+
hidden_features: int,
|
140 |
+
num_heads: int,
|
141 |
+
num_layers: int,
|
142 |
+
out_activation: Optional[nn.Module],
|
143 |
+
):
|
144 |
+
"""
|
145 |
+
Decoder module.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
in_dim (int): The input dimension of the module.
|
149 |
+
conditioning_input_dim (int): The input dimension of the conditioning input.
|
150 |
+
hidden_features (int): The number of hidden features in the module.
|
151 |
+
num_heads (int): The number of heads to use in the attention mechanism.
|
152 |
+
num_layers (int): The number of layers in the module.
|
153 |
+
out_activation (nn.Module): The activation function to use on the output tensor.
|
154 |
+
"""
|
155 |
+
super().__init__()
|
156 |
+
self.residual_projection = nn.Linear(
|
157 |
+
in_dim, hidden_features
|
158 |
+
) # projection for residual connection
|
159 |
+
self.layers = nn.ModuleList(
|
160 |
+
[
|
161 |
+
AttentionLayer(
|
162 |
+
hidden_features, conditioning_input_dim, hidden_features, num_heads
|
163 |
+
)
|
164 |
+
for i in range(num_layers)
|
165 |
+
]
|
166 |
+
)
|
167 |
+
self.fc = nn.Linear(hidden_features, 3) # 3 for RGB
|
168 |
+
self.out_activation = out_activation
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self, x: torch.Tensor, conditioning_input: torch.Tensor
|
172 |
+
) -> torch.Tensor:
|
173 |
+
"""
|
174 |
+
Forward pass of the Decoder module.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
x (torch.Tensor): The input tensor.
|
178 |
+
conditioning_input (torch.Tensor): The conditioning input tensor.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
torch.Tensor: The output tensor of the Decoder module.
|
182 |
+
"""
|
183 |
+
x = self.residual_projection(x)
|
184 |
+
for layer in self.layers:
|
185 |
+
x = layer(x, conditioning_input)
|
186 |
+
x = self.fc(x)
|
187 |
+
if self.out_activation is not None:
|
188 |
+
x = self.out_activation(x)
|
189 |
+
return x
|
spar3d/models/illumination/reni/components/vn_layers.py
ADDED
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Phil Wang
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
"""All code taken from https://github.com/lucidrains/VN-transformer"""
|
24 |
+
|
25 |
+
from collections import namedtuple
|
26 |
+
from functools import wraps
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from einops import rearrange, reduce
|
31 |
+
from einops.layers.torch import Rearrange
|
32 |
+
from packaging import version
|
33 |
+
from torch import einsum, nn
|
34 |
+
|
35 |
+
# constants
|
36 |
+
|
37 |
+
FlashAttentionConfig = namedtuple(
|
38 |
+
"FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
|
39 |
+
)
|
40 |
+
|
41 |
+
# helpers
|
42 |
+
|
43 |
+
|
44 |
+
def exists(val):
|
45 |
+
return val is not None
|
46 |
+
|
47 |
+
|
48 |
+
def once(fn):
|
49 |
+
called = False
|
50 |
+
|
51 |
+
@wraps(fn)
|
52 |
+
def inner(x):
|
53 |
+
nonlocal called
|
54 |
+
if called:
|
55 |
+
return
|
56 |
+
called = True
|
57 |
+
return fn(x)
|
58 |
+
|
59 |
+
return inner
|
60 |
+
|
61 |
+
|
62 |
+
print_once = once(print)
|
63 |
+
|
64 |
+
# main class
|
65 |
+
|
66 |
+
|
67 |
+
class Attend(nn.Module):
|
68 |
+
def __init__(self, dropout=0.0, flash=False, l2_dist=False):
|
69 |
+
super().__init__()
|
70 |
+
assert not (
|
71 |
+
flash and l2_dist
|
72 |
+
), "flash attention is not compatible with l2 distance"
|
73 |
+
self.l2_dist = l2_dist
|
74 |
+
|
75 |
+
self.dropout = dropout
|
76 |
+
self.attn_dropout = nn.Dropout(dropout)
|
77 |
+
|
78 |
+
self.flash = flash
|
79 |
+
assert not (
|
80 |
+
flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
81 |
+
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
82 |
+
|
83 |
+
# determine efficient attention configs for cuda and cpu
|
84 |
+
|
85 |
+
self.cpu_config = FlashAttentionConfig(True, True, True)
|
86 |
+
self.cuda_config = None
|
87 |
+
|
88 |
+
if not torch.cuda.is_available() or not flash:
|
89 |
+
return
|
90 |
+
|
91 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
92 |
+
|
93 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
94 |
+
print_once(
|
95 |
+
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
96 |
+
)
|
97 |
+
self.cuda_config = FlashAttentionConfig(True, False, False)
|
98 |
+
else:
|
99 |
+
print_once(
|
100 |
+
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
101 |
+
)
|
102 |
+
self.cuda_config = FlashAttentionConfig(False, True, True)
|
103 |
+
|
104 |
+
def flash_attn(self, q, k, v, mask=None):
|
105 |
+
_, heads, q_len, _, _, is_cuda = (
|
106 |
+
*q.shape,
|
107 |
+
k.shape[-2],
|
108 |
+
q.is_cuda,
|
109 |
+
)
|
110 |
+
|
111 |
+
# Check if mask exists and expand to compatible shape
|
112 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
113 |
+
|
114 |
+
if exists(mask):
|
115 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
116 |
+
|
117 |
+
# Check if there is a compatible device for flash attention
|
118 |
+
|
119 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
120 |
+
|
121 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
|
122 |
+
|
123 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
124 |
+
out = F.scaled_dot_product_attention(
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
attn_mask=mask,
|
129 |
+
dropout_p=self.dropout if self.training else 0.0,
|
130 |
+
)
|
131 |
+
|
132 |
+
return out
|
133 |
+
|
134 |
+
def forward(self, q, k, v, mask=None):
|
135 |
+
"""
|
136 |
+
einstein notation
|
137 |
+
b - batch
|
138 |
+
h - heads
|
139 |
+
n, i, j - sequence length (base sequence length, source, target)
|
140 |
+
d - feature dimension
|
141 |
+
"""
|
142 |
+
scale = q.shape[-1] ** -0.5
|
143 |
+
|
144 |
+
if exists(mask) and mask.ndim != 4:
|
145 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
146 |
+
|
147 |
+
if self.flash:
|
148 |
+
return self.flash_attn(q, k, v, mask=mask)
|
149 |
+
|
150 |
+
# similarity
|
151 |
+
|
152 |
+
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
|
153 |
+
|
154 |
+
# l2 distance
|
155 |
+
|
156 |
+
if self.l2_dist:
|
157 |
+
# -cdist squared == (-q^2 + 2qk - k^2)
|
158 |
+
# so simply work off the qk above
|
159 |
+
q_squared = reduce(q**2, "b h i d -> b h i 1", "sum")
|
160 |
+
k_squared = reduce(k**2, "b h j d -> b h 1 j", "sum")
|
161 |
+
sim = sim * 2 - q_squared - k_squared
|
162 |
+
|
163 |
+
# key padding mask
|
164 |
+
|
165 |
+
if exists(mask):
|
166 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
167 |
+
|
168 |
+
# attention
|
169 |
+
|
170 |
+
attn = sim.softmax(dim=-1)
|
171 |
+
attn = self.attn_dropout(attn)
|
172 |
+
|
173 |
+
# aggregate values
|
174 |
+
|
175 |
+
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
176 |
+
|
177 |
+
return out
|
178 |
+
|
179 |
+
|
180 |
+
# helper
|
181 |
+
|
182 |
+
|
183 |
+
def exists(val): # noqa: F811
|
184 |
+
return val is not None
|
185 |
+
|
186 |
+
|
187 |
+
def default(val, d):
|
188 |
+
return val if exists(val) else d
|
189 |
+
|
190 |
+
|
191 |
+
def inner_dot_product(x, y, *, dim=-1, keepdim=True):
|
192 |
+
return (x * y).sum(dim=dim, keepdim=keepdim)
|
193 |
+
|
194 |
+
|
195 |
+
# layernorm
|
196 |
+
|
197 |
+
|
198 |
+
class LayerNorm(nn.Module):
|
199 |
+
def __init__(self, dim):
|
200 |
+
super().__init__()
|
201 |
+
self.gamma = nn.Parameter(torch.ones(dim))
|
202 |
+
self.register_buffer("beta", torch.zeros(dim))
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
206 |
+
|
207 |
+
|
208 |
+
# equivariant modules
|
209 |
+
|
210 |
+
|
211 |
+
class VNLinear(nn.Module):
|
212 |
+
def __init__(self, dim_in, dim_out, bias_epsilon=0.0):
|
213 |
+
super().__init__()
|
214 |
+
self.weight = nn.Parameter(torch.randn(dim_out, dim_in))
|
215 |
+
|
216 |
+
self.bias = None
|
217 |
+
self.bias_epsilon = bias_epsilon
|
218 |
+
|
219 |
+
# in this paper, they propose going for quasi-equivariance with a small bias, controllable with epsilon, which they claim lead to better stability and results
|
220 |
+
|
221 |
+
if bias_epsilon > 0.0:
|
222 |
+
self.bias = nn.Parameter(torch.randn(dim_out))
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
out = einsum("... i c, o i -> ... o c", x, self.weight)
|
226 |
+
|
227 |
+
if exists(self.bias):
|
228 |
+
bias = F.normalize(self.bias, dim=-1) * self.bias_epsilon
|
229 |
+
out = out + rearrange(bias, "... -> ... 1")
|
230 |
+
|
231 |
+
return out
|
232 |
+
|
233 |
+
|
234 |
+
class VNReLU(nn.Module):
|
235 |
+
def __init__(self, dim, eps=1e-6):
|
236 |
+
super().__init__()
|
237 |
+
self.eps = eps
|
238 |
+
self.W = nn.Parameter(torch.randn(dim, dim))
|
239 |
+
self.U = nn.Parameter(torch.randn(dim, dim))
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
q = einsum("... i c, o i -> ... o c", x, self.W)
|
243 |
+
k = einsum("... i c, o i -> ... o c", x, self.U)
|
244 |
+
|
245 |
+
qk = inner_dot_product(q, k)
|
246 |
+
|
247 |
+
k_norm = k.norm(dim=-1, keepdim=True).clamp(min=self.eps)
|
248 |
+
q_projected_on_k = q - inner_dot_product(q, k / k_norm) * k
|
249 |
+
|
250 |
+
out = torch.where(qk >= 0.0, q, q_projected_on_k)
|
251 |
+
|
252 |
+
return out
|
253 |
+
|
254 |
+
|
255 |
+
class VNAttention(nn.Module):
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
dim,
|
259 |
+
dim_head=64,
|
260 |
+
heads=8,
|
261 |
+
dim_coor=3,
|
262 |
+
bias_epsilon=0.0,
|
263 |
+
l2_dist_attn=False,
|
264 |
+
flash=False,
|
265 |
+
num_latents=None, # setting this would enable perceiver-like cross attention from latents to sequence, with the latents derived from VNWeightedPool
|
266 |
+
):
|
267 |
+
super().__init__()
|
268 |
+
assert not (
|
269 |
+
l2_dist_attn and flash
|
270 |
+
), "l2 distance attention is not compatible with flash attention"
|
271 |
+
|
272 |
+
self.scale = (dim_coor * dim_head) ** -0.5
|
273 |
+
dim_inner = dim_head * heads
|
274 |
+
self.heads = heads
|
275 |
+
|
276 |
+
self.to_q_input = None
|
277 |
+
if exists(num_latents):
|
278 |
+
self.to_q_input = VNWeightedPool(
|
279 |
+
dim, num_pooled_tokens=num_latents, squeeze_out_pooled_dim=False
|
280 |
+
)
|
281 |
+
|
282 |
+
self.to_q = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
|
283 |
+
self.to_k = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
|
284 |
+
self.to_v = VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon)
|
285 |
+
self.to_out = VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon)
|
286 |
+
|
287 |
+
if l2_dist_attn and not exists(num_latents):
|
288 |
+
# tied queries and keys for l2 distance attention, and not perceiver-like attention
|
289 |
+
self.to_k = self.to_q
|
290 |
+
|
291 |
+
self.attend = Attend(flash=flash, l2_dist=l2_dist_attn)
|
292 |
+
|
293 |
+
def forward(self, x, mask=None):
|
294 |
+
"""
|
295 |
+
einstein notation
|
296 |
+
b - batch
|
297 |
+
n - sequence
|
298 |
+
h - heads
|
299 |
+
d - feature dimension (channels)
|
300 |
+
c - coordinate dimension (3 for 3d space)
|
301 |
+
i - source sequence dimension
|
302 |
+
j - target sequence dimension
|
303 |
+
"""
|
304 |
+
|
305 |
+
c = x.shape[-1]
|
306 |
+
|
307 |
+
if exists(self.to_q_input):
|
308 |
+
q_input = self.to_q_input(x, mask=mask)
|
309 |
+
else:
|
310 |
+
q_input = x
|
311 |
+
|
312 |
+
q, k, v = self.to_q(q_input), self.to_k(x), self.to_v(x)
|
313 |
+
q, k, v = map(
|
314 |
+
lambda t: rearrange(t, "b n (h d) c -> b h n (d c)", h=self.heads),
|
315 |
+
(q, k, v),
|
316 |
+
)
|
317 |
+
|
318 |
+
out = self.attend(q, k, v, mask=mask)
|
319 |
+
|
320 |
+
out = rearrange(out, "b h n (d c) -> b n (h d) c", c=c)
|
321 |
+
return self.to_out(out)
|
322 |
+
|
323 |
+
|
324 |
+
def VNFeedForward(dim, mult=4, bias_epsilon=0.0):
|
325 |
+
dim_inner = int(dim * mult)
|
326 |
+
return nn.Sequential(
|
327 |
+
VNLinear(dim, dim_inner, bias_epsilon=bias_epsilon),
|
328 |
+
VNReLU(dim_inner),
|
329 |
+
VNLinear(dim_inner, dim, bias_epsilon=bias_epsilon),
|
330 |
+
)
|
331 |
+
|
332 |
+
|
333 |
+
class VNLayerNorm(nn.Module):
|
334 |
+
def __init__(self, dim, eps=1e-6):
|
335 |
+
super().__init__()
|
336 |
+
self.eps = eps
|
337 |
+
self.ln = LayerNorm(dim)
|
338 |
+
|
339 |
+
def forward(self, x):
|
340 |
+
norms = x.norm(dim=-1)
|
341 |
+
x = x / rearrange(norms.clamp(min=self.eps), "... -> ... 1")
|
342 |
+
ln_out = self.ln(norms)
|
343 |
+
return x * rearrange(ln_out, "... -> ... 1")
|
344 |
+
|
345 |
+
|
346 |
+
class VNWeightedPool(nn.Module):
|
347 |
+
def __init__(
|
348 |
+
self, dim, dim_out=None, num_pooled_tokens=1, squeeze_out_pooled_dim=True
|
349 |
+
):
|
350 |
+
super().__init__()
|
351 |
+
dim_out = default(dim_out, dim)
|
352 |
+
self.weight = nn.Parameter(torch.randn(num_pooled_tokens, dim, dim_out))
|
353 |
+
self.squeeze_out_pooled_dim = num_pooled_tokens == 1 and squeeze_out_pooled_dim
|
354 |
+
|
355 |
+
def forward(self, x, mask=None):
|
356 |
+
if exists(mask):
|
357 |
+
mask = rearrange(mask, "b n -> b n 1 1")
|
358 |
+
x = x.masked_fill(~mask, 0.0)
|
359 |
+
numer = reduce(x, "b n d c -> b d c", "sum")
|
360 |
+
denom = mask.sum(dim=1)
|
361 |
+
mean_pooled = numer / denom.clamp(min=1e-6)
|
362 |
+
else:
|
363 |
+
mean_pooled = reduce(x, "b n d c -> b d c", "mean")
|
364 |
+
|
365 |
+
out = einsum("b d c, m d e -> b m e c", mean_pooled, self.weight)
|
366 |
+
|
367 |
+
if not self.squeeze_out_pooled_dim:
|
368 |
+
return out
|
369 |
+
|
370 |
+
out = rearrange(out, "b 1 d c -> b d c")
|
371 |
+
return out
|
372 |
+
|
373 |
+
|
374 |
+
# equivariant VN transformer encoder
|
375 |
+
|
376 |
+
|
377 |
+
class VNTransformerEncoder(nn.Module):
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
dim,
|
381 |
+
*,
|
382 |
+
depth,
|
383 |
+
dim_head=64,
|
384 |
+
heads=8,
|
385 |
+
dim_coor=3,
|
386 |
+
ff_mult=4,
|
387 |
+
final_norm=False,
|
388 |
+
bias_epsilon=0.0,
|
389 |
+
l2_dist_attn=False,
|
390 |
+
flash_attn=False,
|
391 |
+
):
|
392 |
+
super().__init__()
|
393 |
+
self.dim = dim
|
394 |
+
self.dim_coor = dim_coor
|
395 |
+
|
396 |
+
self.layers = nn.ModuleList([])
|
397 |
+
|
398 |
+
for _ in range(depth):
|
399 |
+
self.layers.append(
|
400 |
+
nn.ModuleList(
|
401 |
+
[
|
402 |
+
VNAttention(
|
403 |
+
dim=dim,
|
404 |
+
dim_head=dim_head,
|
405 |
+
heads=heads,
|
406 |
+
bias_epsilon=bias_epsilon,
|
407 |
+
l2_dist_attn=l2_dist_attn,
|
408 |
+
flash=flash_attn,
|
409 |
+
),
|
410 |
+
VNLayerNorm(dim),
|
411 |
+
VNFeedForward(dim=dim, mult=ff_mult, bias_epsilon=bias_epsilon),
|
412 |
+
VNLayerNorm(dim),
|
413 |
+
]
|
414 |
+
)
|
415 |
+
)
|
416 |
+
|
417 |
+
self.norm = VNLayerNorm(dim) if final_norm else nn.Identity()
|
418 |
+
|
419 |
+
def forward(self, x, mask=None):
|
420 |
+
*_, d, c = x.shape
|
421 |
+
|
422 |
+
assert (
|
423 |
+
x.ndim == 4 and d == self.dim and c == self.dim_coor
|
424 |
+
), "input needs to be in the shape of (batch, seq, dim ({self.dim}), coordinate dim ({self.dim_coor}))"
|
425 |
+
|
426 |
+
for attn, attn_post_ln, ff, ff_post_ln in self.layers:
|
427 |
+
x = attn_post_ln(attn(x, mask=mask)) + x
|
428 |
+
x = ff_post_ln(ff(x)) + x
|
429 |
+
|
430 |
+
return self.norm(x)
|
431 |
+
|
432 |
+
|
433 |
+
# invariant layers
|
434 |
+
|
435 |
+
|
436 |
+
class VNInvariant(nn.Module):
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
dim,
|
440 |
+
dim_coor=3,
|
441 |
+
):
|
442 |
+
super().__init__()
|
443 |
+
self.mlp = nn.Sequential(
|
444 |
+
VNLinear(dim, dim_coor), VNReLU(dim_coor), Rearrange("... d e -> ... e d")
|
445 |
+
)
|
446 |
+
|
447 |
+
def forward(self, x):
|
448 |
+
return einsum("b n d i, b n i o -> b n o", x, self.mlp(x))
|
449 |
+
|
450 |
+
|
451 |
+
# main class
|
452 |
+
|
453 |
+
|
454 |
+
class VNTransformer(nn.Module):
|
455 |
+
def __init__(
|
456 |
+
self,
|
457 |
+
*,
|
458 |
+
dim,
|
459 |
+
depth,
|
460 |
+
num_tokens=None,
|
461 |
+
dim_feat=None,
|
462 |
+
dim_head=64,
|
463 |
+
heads=8,
|
464 |
+
dim_coor=3,
|
465 |
+
reduce_dim_out=True,
|
466 |
+
bias_epsilon=0.0,
|
467 |
+
l2_dist_attn=False,
|
468 |
+
flash_attn=False,
|
469 |
+
translation_equivariance=False,
|
470 |
+
translation_invariant=False,
|
471 |
+
):
|
472 |
+
super().__init__()
|
473 |
+
self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
|
474 |
+
|
475 |
+
dim_feat = default(dim_feat, 0)
|
476 |
+
self.dim_feat = dim_feat
|
477 |
+
self.dim_coor_total = dim_coor + dim_feat
|
478 |
+
|
479 |
+
assert (int(translation_equivariance) + int(translation_invariant)) <= 1
|
480 |
+
self.translation_equivariance = translation_equivariance
|
481 |
+
self.translation_invariant = translation_invariant
|
482 |
+
|
483 |
+
self.vn_proj_in = nn.Sequential(
|
484 |
+
Rearrange("... c -> ... 1 c"), VNLinear(1, dim, bias_epsilon=bias_epsilon)
|
485 |
+
)
|
486 |
+
|
487 |
+
self.encoder = VNTransformerEncoder(
|
488 |
+
dim=dim,
|
489 |
+
depth=depth,
|
490 |
+
dim_head=dim_head,
|
491 |
+
heads=heads,
|
492 |
+
bias_epsilon=bias_epsilon,
|
493 |
+
dim_coor=self.dim_coor_total,
|
494 |
+
l2_dist_attn=l2_dist_attn,
|
495 |
+
flash_attn=flash_attn,
|
496 |
+
)
|
497 |
+
|
498 |
+
if reduce_dim_out:
|
499 |
+
self.vn_proj_out = nn.Sequential(
|
500 |
+
VNLayerNorm(dim),
|
501 |
+
VNLinear(dim, 1, bias_epsilon=bias_epsilon),
|
502 |
+
Rearrange("... 1 c -> ... c"),
|
503 |
+
)
|
504 |
+
else:
|
505 |
+
self.vn_proj_out = nn.Identity()
|
506 |
+
|
507 |
+
def forward(
|
508 |
+
self, coors, *, feats=None, mask=None, return_concatted_coors_and_feats=False
|
509 |
+
):
|
510 |
+
if self.translation_equivariance or self.translation_invariant:
|
511 |
+
coors_mean = reduce(coors, "... c -> c", "mean")
|
512 |
+
coors = coors - coors_mean
|
513 |
+
|
514 |
+
x = coors # [batch, num_points, 3]
|
515 |
+
|
516 |
+
if exists(feats):
|
517 |
+
if feats.dtype == torch.long:
|
518 |
+
assert exists(
|
519 |
+
self.token_emb
|
520 |
+
), "num_tokens must be given to the VNTransformer (to build the Embedding), if the features are to be given as indices"
|
521 |
+
feats = self.token_emb(feats)
|
522 |
+
|
523 |
+
assert (
|
524 |
+
feats.shape[-1] == self.dim_feat
|
525 |
+
), f"dim_feat should be set to {feats.shape[-1]}"
|
526 |
+
x = torch.cat((x, feats), dim=-1) # [batch, num_points, 3 + dim_feat]
|
527 |
+
|
528 |
+
assert x.shape[-1] == self.dim_coor_total
|
529 |
+
|
530 |
+
x = self.vn_proj_in(x) # [batch, num_points, hidden_dim, 3 + dim_feat]
|
531 |
+
x = self.encoder(x, mask=mask) # [batch, num_points, hidden_dim, 3 + dim_feat]
|
532 |
+
x = self.vn_proj_out(x) # [batch, num_points, 3 + dim_feat]
|
533 |
+
|
534 |
+
coors_out, feats_out = (
|
535 |
+
x[..., :3],
|
536 |
+
x[..., 3:],
|
537 |
+
) # [batch, num_points, 3], [batch, num_points, dim_feat]
|
538 |
+
|
539 |
+
if self.translation_equivariance:
|
540 |
+
coors_out = coors_out + coors_mean
|
541 |
+
|
542 |
+
if not exists(feats):
|
543 |
+
return coors_out
|
544 |
+
|
545 |
+
if return_concatted_coors_and_feats:
|
546 |
+
return torch.cat((coors_out, feats_out), dim=-1)
|
547 |
+
|
548 |
+
return coors_out, feats_out
|
spar3d/models/illumination/reni/env_map.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Dict, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from jaxtyping import Float
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from spar3d.models.utils import BaseModule
|
9 |
+
|
10 |
+
from .field import RENIField
|
11 |
+
|
12 |
+
|
13 |
+
def _direction_from_coordinate(
|
14 |
+
coordinate: Float[Tensor, "*B 2"],
|
15 |
+
) -> Float[Tensor, "*B 3"]:
|
16 |
+
# OpenGL Convention
|
17 |
+
# +X Right
|
18 |
+
# +Y Up
|
19 |
+
# +Z Backward
|
20 |
+
|
21 |
+
u, v = coordinate.unbind(-1)
|
22 |
+
theta = (2 * torch.pi * u) - torch.pi
|
23 |
+
phi = torch.pi * v
|
24 |
+
|
25 |
+
dir = torch.stack(
|
26 |
+
[
|
27 |
+
theta.sin() * phi.sin(),
|
28 |
+
phi.cos(),
|
29 |
+
-1 * theta.cos() * phi.sin(),
|
30 |
+
],
|
31 |
+
-1,
|
32 |
+
)
|
33 |
+
return dir
|
34 |
+
|
35 |
+
|
36 |
+
def _get_sample_coordinates(
|
37 |
+
resolution: List[int], device: Optional[torch.device] = None
|
38 |
+
) -> Float[Tensor, "H W 2"]:
|
39 |
+
return torch.stack(
|
40 |
+
torch.meshgrid(
|
41 |
+
(torch.arange(resolution[1], device=device) + 0.5) / resolution[1],
|
42 |
+
(torch.arange(resolution[0], device=device) + 0.5) / resolution[0],
|
43 |
+
indexing="xy",
|
44 |
+
),
|
45 |
+
-1,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
class RENIEnvMap(BaseModule):
|
50 |
+
@dataclass
|
51 |
+
class Config(BaseModule.Config):
|
52 |
+
reni_config: dict = field(default_factory=dict)
|
53 |
+
resolution: int = 128
|
54 |
+
|
55 |
+
cfg: Config
|
56 |
+
|
57 |
+
def configure(self):
|
58 |
+
self.field = RENIField(self.cfg.reni_config)
|
59 |
+
resolution = (self.cfg.resolution, self.cfg.resolution * 2)
|
60 |
+
sample_directions = _direction_from_coordinate(
|
61 |
+
_get_sample_coordinates(resolution)
|
62 |
+
)
|
63 |
+
self.img_shape = sample_directions.shape[:-1]
|
64 |
+
|
65 |
+
sample_directions_flat = sample_directions.view(-1, 3)
|
66 |
+
# Lastly these have y up but reni expects z up. Rotate 90 degrees on x axis
|
67 |
+
sample_directions_flat = torch.stack(
|
68 |
+
[
|
69 |
+
sample_directions_flat[:, 0],
|
70 |
+
-sample_directions_flat[:, 2],
|
71 |
+
sample_directions_flat[:, 1],
|
72 |
+
],
|
73 |
+
-1,
|
74 |
+
)
|
75 |
+
self.sample_directions = torch.nn.Parameter(
|
76 |
+
sample_directions_flat, requires_grad=False
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
latent_codes: Float[Tensor, "B latent_dim 3"],
|
82 |
+
rotation: Optional[Float[Tensor, "B 3 3"]] = None,
|
83 |
+
scale: Optional[Float[Tensor, "B"]] = None,
|
84 |
+
) -> Dict[str, Tensor]:
|
85 |
+
return {
|
86 |
+
k: v.view(latent_codes.shape[0], *self.img_shape, -1)
|
87 |
+
for k, v in self.field(
|
88 |
+
self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1),
|
89 |
+
latent_codes,
|
90 |
+
rotation=rotation,
|
91 |
+
scale=scale,
|
92 |
+
).items()
|
93 |
+
}
|
spar3d/models/illumination/reni/field.py
ADDED
@@ -0,0 +1,736 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The University of York. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Modified by Mark Boss
|
16 |
+
|
17 |
+
"""RENI field"""
|
18 |
+
|
19 |
+
import contextlib
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Dict, Literal, Optional
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from einops.layers.torch import Rearrange
|
25 |
+
from jaxtyping import Float
|
26 |
+
from torch import Tensor, nn
|
27 |
+
|
28 |
+
from spar3d.models.network import get_activation_module, trunc_exp
|
29 |
+
from spar3d.models.utils import BaseModule
|
30 |
+
|
31 |
+
from .components.film_siren import FiLMSiren
|
32 |
+
from .components.siren import Siren
|
33 |
+
from .components.transformer_decoder import Decoder
|
34 |
+
from .components.vn_layers import VNInvariant, VNLinear
|
35 |
+
|
36 |
+
# from nerfstudio.cameras.rays import RaySamples
|
37 |
+
|
38 |
+
|
39 |
+
def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
|
40 |
+
"""Computes the expected value of sin(y) where y ~ N(x_means, x_vars)
|
41 |
+
|
42 |
+
Args:
|
43 |
+
x_means: Mean values.
|
44 |
+
x_vars: Variance of values.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
torch.Tensor: The expected value of sin.
|
48 |
+
"""
|
49 |
+
|
50 |
+
return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
|
51 |
+
|
52 |
+
|
53 |
+
class NeRFEncoding(torch.nn.Module):
|
54 |
+
"""Multi-scale sinousoidal encodings. Support ``integrated positional encodings`` if covariances are provided.
|
55 |
+
Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
in_dim: Input dimension of tensor
|
59 |
+
num_frequencies: Number of encoded frequencies per axis
|
60 |
+
min_freq_exp: Minimum frequency exponent
|
61 |
+
max_freq_exp: Maximum frequency exponent
|
62 |
+
include_input: Append the input coordinate to the encoding
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
in_dim: int,
|
68 |
+
num_frequencies: int,
|
69 |
+
min_freq_exp: float,
|
70 |
+
max_freq_exp: float,
|
71 |
+
include_input: bool = False,
|
72 |
+
off_axis: bool = False,
|
73 |
+
) -> None:
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.in_dim = in_dim
|
77 |
+
self.num_frequencies = num_frequencies
|
78 |
+
self.min_freq = min_freq_exp
|
79 |
+
self.max_freq = max_freq_exp
|
80 |
+
self.include_input = include_input
|
81 |
+
|
82 |
+
self.off_axis = off_axis
|
83 |
+
|
84 |
+
self.P = torch.tensor(
|
85 |
+
[
|
86 |
+
[0.8506508, 0, 0.5257311],
|
87 |
+
[0.809017, 0.5, 0.309017],
|
88 |
+
[0.5257311, 0.8506508, 0],
|
89 |
+
[1, 0, 0],
|
90 |
+
[0.809017, 0.5, -0.309017],
|
91 |
+
[0.8506508, 0, -0.5257311],
|
92 |
+
[0.309017, 0.809017, -0.5],
|
93 |
+
[0, 0.5257311, -0.8506508],
|
94 |
+
[0.5, 0.309017, -0.809017],
|
95 |
+
[0, 1, 0],
|
96 |
+
[-0.5257311, 0.8506508, 0],
|
97 |
+
[-0.309017, 0.809017, -0.5],
|
98 |
+
[0, 0.5257311, 0.8506508],
|
99 |
+
[-0.309017, 0.809017, 0.5],
|
100 |
+
[0.309017, 0.809017, 0.5],
|
101 |
+
[0.5, 0.309017, 0.809017],
|
102 |
+
[0.5, -0.309017, 0.809017],
|
103 |
+
[0, 0, 1],
|
104 |
+
[-0.5, 0.309017, 0.809017],
|
105 |
+
[-0.809017, 0.5, 0.309017],
|
106 |
+
[-0.809017, 0.5, -0.309017],
|
107 |
+
]
|
108 |
+
).T
|
109 |
+
|
110 |
+
def get_out_dim(self) -> int:
|
111 |
+
if self.in_dim is None:
|
112 |
+
raise ValueError("Input dimension has not been set")
|
113 |
+
out_dim = self.in_dim * self.num_frequencies * 2
|
114 |
+
|
115 |
+
if self.off_axis:
|
116 |
+
out_dim = self.P.shape[1] * self.num_frequencies * 2
|
117 |
+
|
118 |
+
if self.include_input:
|
119 |
+
out_dim += self.in_dim
|
120 |
+
return out_dim
|
121 |
+
|
122 |
+
def forward(
|
123 |
+
self,
|
124 |
+
in_tensor: Float[Tensor, "*b input_dim"],
|
125 |
+
covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
|
126 |
+
) -> Float[Tensor, "*b output_dim"]:
|
127 |
+
"""Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
|
128 |
+
in mip-NeRF.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
in_tensor: For best performance, the input tensor should be between 0 and 1.
|
132 |
+
covs: Covariances of input points.
|
133 |
+
Returns:
|
134 |
+
Output values will be between -1 and 1
|
135 |
+
"""
|
136 |
+
# TODO check scaling here but just comment it for now
|
137 |
+
# in_tensor = 2 * torch.pi * in_tensor # scale to [0, 2pi]
|
138 |
+
freqs = 2 ** torch.linspace(
|
139 |
+
self.min_freq, self.max_freq, self.num_frequencies
|
140 |
+
).to(in_tensor.device)
|
141 |
+
# freqs = 2 ** (
|
142 |
+
# torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq
|
143 |
+
# ).to(in_tensor.device)
|
144 |
+
# freqs = 2 ** (
|
145 |
+
# torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq
|
146 |
+
# )
|
147 |
+
|
148 |
+
if self.off_axis:
|
149 |
+
scaled_inputs = (
|
150 |
+
torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
scaled_inputs = (
|
154 |
+
in_tensor[..., None] * freqs
|
155 |
+
) # [..., "input_dim", "num_scales"]
|
156 |
+
scaled_inputs = scaled_inputs.view(
|
157 |
+
*scaled_inputs.shape[:-2], -1
|
158 |
+
) # [..., "input_dim" * "num_scales"]
|
159 |
+
|
160 |
+
if covs is None:
|
161 |
+
encoded_inputs = torch.sin(
|
162 |
+
torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
input_var = (
|
166 |
+
torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None]
|
167 |
+
* freqs[None, :] ** 2
|
168 |
+
)
|
169 |
+
input_var = input_var.reshape((*input_var.shape[:-2], -1))
|
170 |
+
encoded_inputs = expected_sin(
|
171 |
+
torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1),
|
172 |
+
torch.cat(2 * [input_var], dim=-1),
|
173 |
+
)
|
174 |
+
|
175 |
+
if self.include_input:
|
176 |
+
encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
|
177 |
+
return encoded_inputs
|
178 |
+
|
179 |
+
|
180 |
+
class RENIField(BaseModule):
|
181 |
+
@dataclass
|
182 |
+
class Config(BaseModule.Config):
|
183 |
+
"""Configuration for model instantiation"""
|
184 |
+
|
185 |
+
fixed_decoder: bool = False
|
186 |
+
"""Whether to fix the decoder weights"""
|
187 |
+
equivariance: str = "SO2"
|
188 |
+
"""Type of equivariance to use: None, SO2, SO3"""
|
189 |
+
axis_of_invariance: str = "y"
|
190 |
+
"""Which axis should SO2 equivariance be invariant to: x, y, z"""
|
191 |
+
invariant_function: str = "GramMatrix"
|
192 |
+
"""Type of invariant function to use: GramMatrix, VN"""
|
193 |
+
conditioning: str = "Concat"
|
194 |
+
"""Type of conditioning to use: FiLM, Concat, Attention"""
|
195 |
+
positional_encoding: str = "NeRF"
|
196 |
+
"""Type of positional encoding to use. Currently only NeRF is supported"""
|
197 |
+
encoded_input: str = "Directions"
|
198 |
+
"""Type of input to encode: None, Directions, Conditioning, Both"""
|
199 |
+
latent_dim: int = 36
|
200 |
+
"""Dimensionality of latent code, N for a latent code size of (N x 3)"""
|
201 |
+
hidden_layers: int = 3
|
202 |
+
"""Number of hidden layers"""
|
203 |
+
hidden_features: int = 128
|
204 |
+
"""Number of hidden features"""
|
205 |
+
mapping_layers: int = 3
|
206 |
+
"""Number of mapping layers"""
|
207 |
+
mapping_features: int = 128
|
208 |
+
"""Number of mapping features"""
|
209 |
+
num_attention_heads: int = 8
|
210 |
+
"""Number of attention heads"""
|
211 |
+
num_attention_layers: int = 3
|
212 |
+
"""Number of attention layers"""
|
213 |
+
out_features: int = 3 # RGB
|
214 |
+
"""Number of output features"""
|
215 |
+
last_layer_linear: bool = False
|
216 |
+
"""Whether to use a linear layer as the last layer"""
|
217 |
+
output_activation: str = "exp"
|
218 |
+
"""Activation function for output layer: sigmoid, tanh, relu, exp, None"""
|
219 |
+
first_omega_0: float = 30.0
|
220 |
+
"""Omega_0 for first layer"""
|
221 |
+
hidden_omega_0: float = 30.0
|
222 |
+
"""Omega_0 for hidden layers"""
|
223 |
+
fixed_decoder: bool = False
|
224 |
+
"""Whether to fix the decoder weights"""
|
225 |
+
old_implementation: bool = False
|
226 |
+
"""Whether to match implementation of old RENI, when using old checkpoints"""
|
227 |
+
|
228 |
+
cfg: Config
|
229 |
+
|
230 |
+
def configure(self):
|
231 |
+
self.equivariance = self.cfg.equivariance
|
232 |
+
self.conditioning = self.cfg.conditioning
|
233 |
+
self.latent_dim = self.cfg.latent_dim
|
234 |
+
self.hidden_layers = self.cfg.hidden_layers
|
235 |
+
self.hidden_features = self.cfg.hidden_features
|
236 |
+
self.mapping_layers = self.cfg.mapping_layers
|
237 |
+
self.mapping_features = self.cfg.mapping_features
|
238 |
+
self.out_features = self.cfg.out_features
|
239 |
+
self.last_layer_linear = self.cfg.last_layer_linear
|
240 |
+
self.output_activation = self.cfg.output_activation
|
241 |
+
self.first_omega_0 = self.cfg.first_omega_0
|
242 |
+
self.hidden_omega_0 = self.cfg.hidden_omega_0
|
243 |
+
self.old_implementation = self.cfg.old_implementation
|
244 |
+
self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)
|
245 |
+
|
246 |
+
self.fixed_decoder = self.cfg.fixed_decoder
|
247 |
+
if self.cfg.invariant_function == "GramMatrix":
|
248 |
+
self.invariant_function = self.gram_matrix_invariance
|
249 |
+
else:
|
250 |
+
self.vn_proj_in = nn.Sequential(
|
251 |
+
Rearrange("... c -> ... 1 c"),
|
252 |
+
VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
|
253 |
+
)
|
254 |
+
dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
|
255 |
+
self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
|
256 |
+
self.invariant_function = self.vn_invariance
|
257 |
+
|
258 |
+
self.network = self.setup_network()
|
259 |
+
|
260 |
+
if self.fixed_decoder:
|
261 |
+
for param in self.network.parameters():
|
262 |
+
param.requires_grad = False
|
263 |
+
|
264 |
+
if self.cfg.invariant_function == "VN":
|
265 |
+
for param in self.vn_proj_in.parameters():
|
266 |
+
param.requires_grad = False
|
267 |
+
for param in self.vn_invar.parameters():
|
268 |
+
param.requires_grad = False
|
269 |
+
|
270 |
+
@contextlib.contextmanager
|
271 |
+
def hold_decoder_fixed(self):
|
272 |
+
"""Context manager to fix the decoder weights
|
273 |
+
|
274 |
+
Example usage:
|
275 |
+
```
|
276 |
+
with instance_of_RENIField.hold_decoder_fixed():
|
277 |
+
# do stuff
|
278 |
+
```
|
279 |
+
"""
|
280 |
+
prev_state_network = {
|
281 |
+
name: p.requires_grad for name, p in self.network.named_parameters()
|
282 |
+
}
|
283 |
+
for param in self.network.parameters():
|
284 |
+
param.requires_grad = False
|
285 |
+
if self.cfg.invariant_function == "VN":
|
286 |
+
prev_state_proj_in = {
|
287 |
+
k: p.requires_grad for k, p in self.vn_proj_in.named_parameters()
|
288 |
+
}
|
289 |
+
prev_state_invar = {
|
290 |
+
k: p.requires_grad for k, p in self.vn_invar.named_parameters()
|
291 |
+
}
|
292 |
+
for param in self.vn_proj_in.parameters():
|
293 |
+
param.requires_grad = False
|
294 |
+
for param in self.vn_invar.parameters():
|
295 |
+
param.requires_grad = False
|
296 |
+
|
297 |
+
prev_decoder_state = self.fixed_decoder
|
298 |
+
self.fixed_decoder = True
|
299 |
+
try:
|
300 |
+
yield
|
301 |
+
finally:
|
302 |
+
# Restore the previous requires_grad state
|
303 |
+
for name, param in self.network.named_parameters():
|
304 |
+
param.requires_grad = prev_state_network[name]
|
305 |
+
if self.cfg.invariant_function == "VN":
|
306 |
+
for name, param in self.vn_proj_in.named_parameters():
|
307 |
+
param.requires_grad_(prev_state_proj_in[name])
|
308 |
+
for name, param in self.vn_invar.named_parameters():
|
309 |
+
param.requires_grad_(prev_state_invar[name])
|
310 |
+
self.fixed_decoder = prev_decoder_state
|
311 |
+
|
312 |
+
def vn_invariance(
|
313 |
+
self,
|
314 |
+
Z: Float[Tensor, "B latent_dim 3"],
|
315 |
+
D: Float[Tensor, "B num_rays 3"],
|
316 |
+
equivariance: Literal["None", "SO2", "SO3"] = "SO2",
|
317 |
+
axis_of_invariance: int = 1,
|
318 |
+
):
|
319 |
+
"""Generates a batched invariant representation from latent code Z and direction coordinates D.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
Z: [B, latent_dim, 3] - Latent code.
|
323 |
+
D: [B num_rays, 3] - Direction coordinates.
|
324 |
+
equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
|
325 |
+
axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
Tuple[Tensor, Tensor]: directional_input, conditioning_input
|
329 |
+
"""
|
330 |
+
assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
|
331 |
+
other_axes = [i for i in range(3) if i != axis_of_invariance]
|
332 |
+
|
333 |
+
B, latent_dim, _ = Z.shape
|
334 |
+
_, num_rays, _ = D.shape
|
335 |
+
|
336 |
+
if equivariance == "None":
|
337 |
+
# get inner product between latent code and direction coordinates
|
338 |
+
innerprod = torch.sum(
|
339 |
+
Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
|
340 |
+
) # [B, num_rays, latent_dim]
|
341 |
+
z_input = (
|
342 |
+
Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
|
343 |
+
) # [B, num_rays, latent_dim * 3]
|
344 |
+
return innerprod, z_input
|
345 |
+
|
346 |
+
if equivariance == "SO2":
|
347 |
+
z_other = torch.stack(
|
348 |
+
(Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
|
349 |
+
) # [B, latent_dim, 2]
|
350 |
+
d_other = torch.stack(
|
351 |
+
(D[..., other_axes[0]], D[..., other_axes[1]]), -1
|
352 |
+
).unsqueeze(2) # [B, num_rays, 1, 2]
|
353 |
+
d_other = d_other.expand(
|
354 |
+
B, num_rays, latent_dim, 2
|
355 |
+
) # [B, num_rays, latent_dim, 2]
|
356 |
+
|
357 |
+
z_other_emb = self.vn_proj_in(z_other) # [B, latent_dim, 1, 2]
|
358 |
+
z_other_invar = self.vn_invar(z_other_emb) # [B, latent_dim, 2]
|
359 |
+
|
360 |
+
# Get invariant component of Z along the axis of invariance
|
361 |
+
z_invar = Z[..., axis_of_invariance].unsqueeze(-1) # [B, latent_dim, 1]
|
362 |
+
|
363 |
+
# Innerproduct between projection of Z and D on the plane orthogonal to the axis of invariance.
|
364 |
+
# This encodes the rotational information. This is rotation-equivariant to rotations of either Z
|
365 |
+
# or D and is invariant to rotations of both Z and D.
|
366 |
+
innerprod = (z_other.unsqueeze(1) * d_other).sum(
|
367 |
+
dim=-1
|
368 |
+
) # [B, num_rays, latent_dim]
|
369 |
+
|
370 |
+
# Compute norm along the axes orthogonal to the axis of invariance
|
371 |
+
d_other_norm = torch.sqrt(
|
372 |
+
D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
|
373 |
+
).unsqueeze(-1) # [B num_rays, 1]
|
374 |
+
|
375 |
+
# Get invariant component of D along the axis of invariance
|
376 |
+
d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
|
377 |
+
|
378 |
+
directional_input = torch.cat(
|
379 |
+
(innerprod, d_invar, d_other_norm), -1
|
380 |
+
) # [B, num_rays, latent_dim + 2]
|
381 |
+
conditioning_input = (
|
382 |
+
torch.cat((z_other_invar, z_invar), dim=-1)
|
383 |
+
.flatten(1)
|
384 |
+
.unsqueeze(1)
|
385 |
+
.expand(B, num_rays, latent_dim * 3)
|
386 |
+
) # [B, num_rays, latent_dim * 3]
|
387 |
+
|
388 |
+
return directional_input, conditioning_input
|
389 |
+
|
390 |
+
if equivariance == "SO3":
|
391 |
+
z = self.vn_proj_in(Z) # [B, latent_dim, 1, 3]
|
392 |
+
z_invar = self.vn_invar(z) # [B, latent_dim, 3]
|
393 |
+
conditioning_input = (
|
394 |
+
z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim)
|
395 |
+
) # [B, num_rays, latent_dim * 3]
|
396 |
+
# D [B, num_rays, 3] -> [B, num_rays, 1, 3]
|
397 |
+
# Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3]
|
398 |
+
innerprod = torch.sum(
|
399 |
+
Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
|
400 |
+
) # [B, num_rays, latent_dim]
|
401 |
+
return innerprod, conditioning_input
|
402 |
+
|
403 |
+
def gram_matrix_invariance(
|
404 |
+
self,
|
405 |
+
Z: Float[Tensor, "B latent_dim 3"],
|
406 |
+
D: Float[Tensor, "B num_rays 3"],
|
407 |
+
equivariance: Literal["None", "SO2", "SO3"] = "SO2",
|
408 |
+
axis_of_invariance: int = 1,
|
409 |
+
):
|
410 |
+
"""Generates an invariant representation from latent code Z and direction coordinates D.
|
411 |
+
|
412 |
+
Args:
|
413 |
+
Z (torch.Tensor): Latent code (B x latent_dim x 3)
|
414 |
+
D (torch.Tensor): Direction coordinates (B x num_rays x 3)
|
415 |
+
equivariance (str): Type of equivariance to use. Options are 'none', 'SO2', and 'SO3'
|
416 |
+
axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
|
417 |
+
Default is 1 (y-axis).
|
418 |
+
Returns:
|
419 |
+
torch.Tensor: Invariant representation
|
420 |
+
"""
|
421 |
+
assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
|
422 |
+
other_axes = [i for i in range(3) if i != axis_of_invariance]
|
423 |
+
|
424 |
+
B, latent_dim, _ = Z.shape
|
425 |
+
_, num_rays, _ = D.shape
|
426 |
+
|
427 |
+
if equivariance == "None":
|
428 |
+
# get inner product between latent code and direction coordinates
|
429 |
+
innerprod = torch.sum(
|
430 |
+
Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
|
431 |
+
) # [B, num_rays, latent_dim]
|
432 |
+
z_input = (
|
433 |
+
Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
|
434 |
+
) # [B, num_rays, latent_dim * 3]
|
435 |
+
return innerprod, z_input
|
436 |
+
|
437 |
+
if equivariance == "SO2":
|
438 |
+
# Select components along axes orthogonal to the axis of invariance
|
439 |
+
z_other = torch.stack(
|
440 |
+
(Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
|
441 |
+
) # [B, latent_dim, 2]
|
442 |
+
d_other = torch.stack(
|
443 |
+
(D[..., other_axes[0]], D[..., other_axes[1]]), -1
|
444 |
+
).unsqueeze(2) # [B, num_rays, 1, 2]
|
445 |
+
d_other = d_other.expand(
|
446 |
+
B, num_rays, latent_dim, 2
|
447 |
+
) # size becomes [B, num_rays, latent_dim, 2]
|
448 |
+
|
449 |
+
# Invariant representation of Z, gram matrix G=Z*Z' is size num_rays x latent_dim x latent_dim
|
450 |
+
G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))
|
451 |
+
|
452 |
+
# Flatten G to be size B x latent_dim^2
|
453 |
+
z_other_invar = G.flatten(start_dim=1)
|
454 |
+
|
455 |
+
# Get invariant component of Z along the axis of invariance
|
456 |
+
z_invar = Z[..., axis_of_invariance] # [B, latent_dim]
|
457 |
+
|
458 |
+
# Innerprod is size num_rays x latent_dim
|
459 |
+
innerprod = (z_other.unsqueeze(1) * d_other).sum(
|
460 |
+
dim=-1
|
461 |
+
) # [B, num_rays, latent_dim]
|
462 |
+
|
463 |
+
# Compute norm along the axes orthogonal to the axis of invariance
|
464 |
+
d_other_norm = torch.sqrt(
|
465 |
+
D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
|
466 |
+
).unsqueeze(-1) # [B, num_rays, 1]
|
467 |
+
|
468 |
+
# Get invariant component of D along the axis of invariance
|
469 |
+
d_invar = D[..., axis_of_invariance].unsqueeze(-1) # [B, num_rays, 1]
|
470 |
+
|
471 |
+
if not self.old_implementation:
|
472 |
+
directional_input = torch.cat(
|
473 |
+
(innerprod, d_invar, d_other_norm), -1
|
474 |
+
) # [B, num_rays, latent_dim + 2]
|
475 |
+
conditioning_input = (
|
476 |
+
torch.cat((z_other_invar, z_invar), -1)
|
477 |
+
.unsqueeze(1)
|
478 |
+
.expand(B, num_rays, latent_dim * 3)
|
479 |
+
) # [B, num_rays, latent_dim^2 + latent_dim]
|
480 |
+
else:
|
481 |
+
# this is matching the previous implementation of RENI, needed if using old checkpoints
|
482 |
+
z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
|
483 |
+
z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
|
484 |
+
return torch.cat(
|
485 |
+
(innerprod, z_other_invar, d_other_norm, z_invar, d_invar), 1
|
486 |
+
)
|
487 |
+
|
488 |
+
return directional_input, conditioning_input
|
489 |
+
|
490 |
+
if equivariance == "SO3":
|
491 |
+
G = Z @ torch.transpose(Z, 1, 2) # [B, latent_dim, latent_dim]
|
492 |
+
innerprod = torch.sum(
|
493 |
+
Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
|
494 |
+
) # [B, num_rays, latent_dim]
|
495 |
+
z_invar = (
|
496 |
+
G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
|
497 |
+
) # [B, num_rays, latent_dim^2]
|
498 |
+
return innerprod, z_invar
|
499 |
+
|
500 |
+
def setup_network(self):
|
501 |
+
"""Sets up the network architecture"""
|
502 |
+
base_input_dims = {
|
503 |
+
"VN": {
|
504 |
+
"None": {
|
505 |
+
"direction": self.latent_dim,
|
506 |
+
"conditioning": self.latent_dim * 3,
|
507 |
+
},
|
508 |
+
"SO2": {
|
509 |
+
"direction": self.latent_dim + 2,
|
510 |
+
"conditioning": self.latent_dim * 3,
|
511 |
+
},
|
512 |
+
"SO3": {
|
513 |
+
"direction": self.latent_dim,
|
514 |
+
"conditioning": self.latent_dim * 3,
|
515 |
+
},
|
516 |
+
},
|
517 |
+
"GramMatrix": {
|
518 |
+
"None": {
|
519 |
+
"direction": self.latent_dim,
|
520 |
+
"conditioning": self.latent_dim * 3,
|
521 |
+
},
|
522 |
+
"SO2": {
|
523 |
+
"direction": self.latent_dim + 2,
|
524 |
+
"conditioning": self.latent_dim**2 + self.latent_dim,
|
525 |
+
},
|
526 |
+
"SO3": {
|
527 |
+
"direction": self.latent_dim,
|
528 |
+
"conditioning": self.latent_dim**2,
|
529 |
+
},
|
530 |
+
},
|
531 |
+
}
|
532 |
+
|
533 |
+
# Extract the necessary input dimensions
|
534 |
+
input_types = ["direction", "conditioning"]
|
535 |
+
input_dims = {
|
536 |
+
key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
|
537 |
+
key
|
538 |
+
]
|
539 |
+
for key in input_types
|
540 |
+
}
|
541 |
+
|
542 |
+
# Helper function to create NeRF encoding
|
543 |
+
def create_nerf_encoding(in_dim):
|
544 |
+
return NeRFEncoding(
|
545 |
+
in_dim=in_dim,
|
546 |
+
num_frequencies=2,
|
547 |
+
min_freq_exp=0.0,
|
548 |
+
max_freq_exp=2.0,
|
549 |
+
include_input=True,
|
550 |
+
)
|
551 |
+
|
552 |
+
# Dictionary-based encoding setup
|
553 |
+
encoding_setup = {
|
554 |
+
"None": [],
|
555 |
+
"Conditioning": ["conditioning"],
|
556 |
+
"Directions": ["direction"],
|
557 |
+
"Both": ["direction", "conditioning"],
|
558 |
+
}
|
559 |
+
|
560 |
+
# Setting up the required encodings
|
561 |
+
for input_type in encoding_setup.get(self.cfg.encoded_input, []):
|
562 |
+
# create self.{input_type}_encoding and update input_dims
|
563 |
+
setattr(
|
564 |
+
self,
|
565 |
+
f"{input_type}_encoding",
|
566 |
+
create_nerf_encoding(input_dims[input_type]),
|
567 |
+
)
|
568 |
+
input_dims[input_type] = getattr(
|
569 |
+
self, f"{input_type}_encoding"
|
570 |
+
).get_out_dim()
|
571 |
+
|
572 |
+
output_activation = get_activation_module(self.cfg.output_activation)
|
573 |
+
|
574 |
+
network = None
|
575 |
+
if self.conditioning == "Concat":
|
576 |
+
network = Siren(
|
577 |
+
in_dim=input_dims["direction"] + input_dims["conditioning"],
|
578 |
+
hidden_layers=self.hidden_layers,
|
579 |
+
hidden_features=self.hidden_features,
|
580 |
+
out_dim=self.out_features,
|
581 |
+
outermost_linear=self.last_layer_linear,
|
582 |
+
first_omega_0=self.first_omega_0,
|
583 |
+
hidden_omega_0=self.hidden_omega_0,
|
584 |
+
out_activation=output_activation,
|
585 |
+
)
|
586 |
+
elif self.conditioning == "FiLM":
|
587 |
+
network = FiLMSiren(
|
588 |
+
in_dim=input_dims["direction"],
|
589 |
+
hidden_layers=self.hidden_layers,
|
590 |
+
hidden_features=self.hidden_features,
|
591 |
+
mapping_network_in_dim=input_dims["conditioning"],
|
592 |
+
mapping_network_layers=self.mapping_layers,
|
593 |
+
mapping_network_features=self.mapping_features,
|
594 |
+
out_dim=self.out_features,
|
595 |
+
outermost_linear=True,
|
596 |
+
out_activation=output_activation,
|
597 |
+
)
|
598 |
+
elif self.conditioning == "Attention":
|
599 |
+
# transformer where K, V is from conditioning input and Q is from pos encoded directional input
|
600 |
+
network = Decoder(
|
601 |
+
in_dim=input_dims["direction"],
|
602 |
+
conditioning_input_dim=input_dims["conditioning"],
|
603 |
+
hidden_features=self.cfg.hidden_features,
|
604 |
+
num_heads=self.cfg.num_attention_heads,
|
605 |
+
num_layers=self.cfg.num_attention_layers,
|
606 |
+
out_activation=output_activation,
|
607 |
+
)
|
608 |
+
assert network is not None, "unknown conditioning type"
|
609 |
+
return network
|
610 |
+
|
611 |
+
def apply_positional_encoding(self, directional_input, conditioning_input):
|
612 |
+
# conditioning on just invariant directional input
|
613 |
+
if self.cfg.encoded_input == "Conditioning":
|
614 |
+
conditioning_input = self.conditioning_encoding(
|
615 |
+
conditioning_input
|
616 |
+
) # [num_rays, embedding_dim]
|
617 |
+
elif self.cfg.encoded_input == "Directions":
|
618 |
+
directional_input = self.direction_encoding(
|
619 |
+
directional_input
|
620 |
+
) # [num_rays, embedding_dim]
|
621 |
+
elif self.cfg.encoded_input == "Both":
|
622 |
+
directional_input = self.direction_encoding(directional_input)
|
623 |
+
conditioning_input = self.conditioning_encoding(conditioning_input)
|
624 |
+
|
625 |
+
return directional_input, conditioning_input
|
626 |
+
|
627 |
+
def get_outputs(
|
628 |
+
self,
|
629 |
+
rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
|
630 |
+
latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
|
631 |
+
rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
|
632 |
+
scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
|
633 |
+
) -> Dict[str, Tensor]:
|
634 |
+
"""Returns the outputs of the field.
|
635 |
+
|
636 |
+
Args:
|
637 |
+
ray_samples: [batch_size num_rays 3]
|
638 |
+
latent_codes: [batch_size, latent_dim, 3]
|
639 |
+
rotation: [batch_size, 3, 3]
|
640 |
+
scale: [batch_size]
|
641 |
+
"""
|
642 |
+
if rotation is not None:
|
643 |
+
if len(rotation.shape) == 3: # [batch_size, 3, 3]
|
644 |
+
# Expand latent_codes to match [batch_size, latent_dim, 3]
|
645 |
+
latent_codes = torch.einsum(
|
646 |
+
"bik,blk->bli",
|
647 |
+
rotation,
|
648 |
+
latent_codes,
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
raise NotImplementedError(
|
652 |
+
"Unsupported rotation shape. Expected [batch_size, 3, 3]."
|
653 |
+
)
|
654 |
+
|
655 |
+
B, num_rays, _ = rays_d.shape
|
656 |
+
_, latent_dim, _ = latent_codes.shape
|
657 |
+
|
658 |
+
if not self.old_implementation:
|
659 |
+
directional_input, conditioning_input = self.invariant_function(
|
660 |
+
latent_codes,
|
661 |
+
rays_d,
|
662 |
+
equivariance=self.equivariance,
|
663 |
+
axis_of_invariance=self.axis_of_invariance,
|
664 |
+
) # [B, num_rays, 3]
|
665 |
+
|
666 |
+
if self.cfg.positional_encoding == "NeRF":
|
667 |
+
directional_input, conditioning_input = self.apply_positional_encoding(
|
668 |
+
directional_input, conditioning_input
|
669 |
+
)
|
670 |
+
|
671 |
+
if self.conditioning == "Concat":
|
672 |
+
model_outputs = self.network(
|
673 |
+
torch.cat((directional_input, conditioning_input), dim=-1).reshape(
|
674 |
+
B * num_rays, -1
|
675 |
+
)
|
676 |
+
).view(B, num_rays, 3) # returns -> [B num_rays, 3]
|
677 |
+
elif self.conditioning == "FiLM":
|
678 |
+
model_outputs = self.network(
|
679 |
+
directional_input.reshape(B * num_rays, -1),
|
680 |
+
conditioning_input.reshape(B * num_rays, -1),
|
681 |
+
).view(B, num_rays, 3) # returns -> [B num_rays, 3]
|
682 |
+
elif self.conditioning == "Attention":
|
683 |
+
model_outputs = self.network(
|
684 |
+
directional_input.reshape(B * num_rays, -1),
|
685 |
+
conditioning_input.reshape(B * num_rays, -1),
|
686 |
+
).view(B, num_rays, 3) # returns -> [B num_rays, 3]
|
687 |
+
else:
|
688 |
+
# in the old implementation directions were sampled with y-up not z-up so need to swap y and z in directions
|
689 |
+
directions = torch.stack(
|
690 |
+
(rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
|
691 |
+
)
|
692 |
+
model_input = self.invariant_function(
|
693 |
+
latent_codes,
|
694 |
+
directions,
|
695 |
+
equivariance=self.equivariance,
|
696 |
+
axis_of_invariance=self.axis_of_invariance,
|
697 |
+
) # [B, num_rays, 3]
|
698 |
+
|
699 |
+
model_outputs = self.network(model_input.view(B * num_rays, -1)).view(
|
700 |
+
B, num_rays, 3
|
701 |
+
)
|
702 |
+
|
703 |
+
outputs = {}
|
704 |
+
|
705 |
+
if scale is not None:
|
706 |
+
scale = trunc_exp(scale) # [num_rays] exp to ensure positive
|
707 |
+
model_outputs = model_outputs * scale.view(-1, 1, 1) # [num_rays, 3]
|
708 |
+
|
709 |
+
outputs["rgb"] = model_outputs
|
710 |
+
|
711 |
+
return outputs
|
712 |
+
|
713 |
+
def forward(
|
714 |
+
self,
|
715 |
+
rays_d: Float[Tensor, "batch num_rays 3"], # type: ignore
|
716 |
+
latent_codes: Float[Tensor, "batch_size latent_dim 3"], # type: ignore
|
717 |
+
rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None, # type: ignore
|
718 |
+
scale: Optional[Float[Tensor, "batch_size"]] = None, # type: ignore
|
719 |
+
) -> Dict[str, Tensor]:
|
720 |
+
"""Evaluates spherical field for a given ray bundle and rotation.
|
721 |
+
|
722 |
+
Args:
|
723 |
+
ray_samples: [B num_rays 3]
|
724 |
+
latent_codes: [B, num_rays, latent_dim, 3]
|
725 |
+
rotation: [batch_size, 3, 3]
|
726 |
+
scale: [batch_size]
|
727 |
+
|
728 |
+
Returns:
|
729 |
+
Dict[str, Tensor]: A dictionary containing the outputs of the field.
|
730 |
+
"""
|
731 |
+
return self.get_outputs(
|
732 |
+
rays_d=rays_d,
|
733 |
+
latent_codes=latent_codes,
|
734 |
+
rotation=rotation,
|
735 |
+
scale=scale,
|
736 |
+
)
|
spar3d/models/image_estimator/clip_based_estimator.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Any, List, Optional
|
3 |
+
|
4 |
+
import alpha_clip
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
from torchvision.transforms import Normalize
|
10 |
+
|
11 |
+
from spar3d.models.network import get_activation
|
12 |
+
from spar3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
15 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class HeadSpec:
|
20 |
+
name: str
|
21 |
+
out_channels: int
|
22 |
+
n_hidden_layers: int
|
23 |
+
output_activation: Optional[str] = None
|
24 |
+
output_bias: float = 0.0
|
25 |
+
add_to_decoder_features: bool = False
|
26 |
+
shape: Optional[list[int]] = None
|
27 |
+
distribution_eval: str = "sample"
|
28 |
+
|
29 |
+
|
30 |
+
class ClipBasedHeadEstimator(BaseModule):
|
31 |
+
@dataclass
|
32 |
+
class Config(BaseModule.Config):
|
33 |
+
model: str = "ViT-L/14@336px"
|
34 |
+
|
35 |
+
distribution: str = "beta"
|
36 |
+
|
37 |
+
# ["mean", "mode", "sample", "sample_mean"]
|
38 |
+
distribution_eval: str = "mode"
|
39 |
+
|
40 |
+
activation: str = "relu"
|
41 |
+
hidden_features: int = 512
|
42 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
43 |
+
|
44 |
+
cfg: Config
|
45 |
+
|
46 |
+
def configure(self):
|
47 |
+
self.model, _ = alpha_clip.load(
|
48 |
+
self.cfg.model,
|
49 |
+
) # change to your own ckpt path
|
50 |
+
self.model.eval()
|
51 |
+
|
52 |
+
if not hasattr(self.model.visual, "input_resolution"):
|
53 |
+
self.img_size = 224
|
54 |
+
else:
|
55 |
+
self.img_size = self.model.visual.input_resolution
|
56 |
+
# Check if img_size is subscribable and pick the first element
|
57 |
+
if hasattr(self.img_size, "__getitem__"):
|
58 |
+
self.img_size = self.img_size[0]
|
59 |
+
|
60 |
+
# Do not add the weights in self.model to the optimizer
|
61 |
+
for param in self.model.parameters():
|
62 |
+
param.requires_grad = False
|
63 |
+
|
64 |
+
assert len(self.cfg.heads) > 0
|
65 |
+
heads = {}
|
66 |
+
for head in self.cfg.heads:
|
67 |
+
head_layers = []
|
68 |
+
in_feature = self.model.visual.output_dim
|
69 |
+
|
70 |
+
for i in range(head.n_hidden_layers):
|
71 |
+
head_layers += [
|
72 |
+
nn.Linear(
|
73 |
+
in_feature if i == 0 else self.cfg.hidden_features,
|
74 |
+
self.cfg.hidden_features,
|
75 |
+
),
|
76 |
+
self.make_activation(self.cfg.activation),
|
77 |
+
]
|
78 |
+
|
79 |
+
head_layers = [nn.Sequential(*head_layers)]
|
80 |
+
head_layers += [
|
81 |
+
nn.Sequential(
|
82 |
+
nn.Linear(
|
83 |
+
self.cfg.hidden_features,
|
84 |
+
self.cfg.hidden_features,
|
85 |
+
),
|
86 |
+
self.make_activation(self.cfg.activation),
|
87 |
+
nn.Linear(self.cfg.hidden_features, 1),
|
88 |
+
)
|
89 |
+
for _ in range(2)
|
90 |
+
]
|
91 |
+
heads[head.name] = nn.ModuleList(head_layers)
|
92 |
+
self.heads = nn.ModuleDict(heads)
|
93 |
+
|
94 |
+
def make_activation(self, activation):
|
95 |
+
if activation == "relu":
|
96 |
+
return nn.ReLU(inplace=True)
|
97 |
+
elif activation == "silu":
|
98 |
+
return nn.SiLU(inplace=True)
|
99 |
+
else:
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
def forward(
|
103 |
+
self,
|
104 |
+
cond_image: Float[Tensor, "B 1 H W 4"],
|
105 |
+
sample: bool = True,
|
106 |
+
) -> dict[str, Any]:
|
107 |
+
# Run the model
|
108 |
+
# Resize cond_image to 224
|
109 |
+
cond_image = cond_image.flatten(0, 1)
|
110 |
+
cond_image = nn.functional.interpolate(
|
111 |
+
cond_image.permute(0, 3, 1, 2),
|
112 |
+
size=(self.img_size, self.img_size),
|
113 |
+
mode="bilinear",
|
114 |
+
align_corners=False,
|
115 |
+
)
|
116 |
+
mask = cond_image[:, 3:4]
|
117 |
+
cond_image = cond_image[:, :3] * mask
|
118 |
+
cond_image = Normalize(
|
119 |
+
mean=OPENAI_DATASET_MEAN,
|
120 |
+
std=OPENAI_DATASET_STD,
|
121 |
+
)(cond_image)
|
122 |
+
mask = Normalize(0.5, 0.26)(mask).half()
|
123 |
+
image_features = self.model.visual(cond_image.half(), mask).float()
|
124 |
+
|
125 |
+
# Run the heads
|
126 |
+
outputs = {}
|
127 |
+
|
128 |
+
for head_dict in self.cfg.heads:
|
129 |
+
head_name = head_dict.name
|
130 |
+
shared_head, d1_h, d2_h = self.heads[head_name]
|
131 |
+
shared_features = shared_head(image_features)
|
132 |
+
d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
|
133 |
+
if self.cfg.distribution == "normal":
|
134 |
+
mean = d1
|
135 |
+
var = d2
|
136 |
+
if mean.shape[-1] == 1:
|
137 |
+
outputs[head_name] = torch.distributions.Normal(
|
138 |
+
mean + head_dict.output_bias,
|
139 |
+
torch.nn.functional.softplus(var),
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
outputs[head_name] = torch.distributions.MultivariateNormal(
|
143 |
+
mean + head_dict.output_bias,
|
144 |
+
torch.nn.functional.softplus(var).diag_embed(),
|
145 |
+
)
|
146 |
+
elif self.cfg.distribution == "beta":
|
147 |
+
outputs[head_name] = torch.distributions.Beta(
|
148 |
+
torch.nn.functional.softplus(d1 + head_dict.output_bias),
|
149 |
+
torch.nn.functional.softplus(d2 + head_dict.output_bias),
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
raise NotImplementedError
|
153 |
+
|
154 |
+
if sample:
|
155 |
+
for head_dict in self.cfg.heads:
|
156 |
+
head_name = head_dict.name
|
157 |
+
dist = outputs[head_name]
|
158 |
+
|
159 |
+
if head_dict.distribution_eval == "mean":
|
160 |
+
out = dist.mean
|
161 |
+
elif head_dict.distribution_eval == "mode":
|
162 |
+
out = dist.mode
|
163 |
+
elif head_dict.distribution_eval == "sample_mean":
|
164 |
+
out = dist.sample([10]).mean(-1)
|
165 |
+
else:
|
166 |
+
# use rsample if gradient is needed
|
167 |
+
out = dist.rsample() if self.training else dist.sample()
|
168 |
+
|
169 |
+
outputs[head_name] = get_activation(head_dict.output_activation)(out)
|
170 |
+
outputs[f"{head_name}_dist"] = dist
|
171 |
+
|
172 |
+
for head in self.cfg.heads:
|
173 |
+
if head.shape:
|
174 |
+
if not sample:
|
175 |
+
raise ValueError(
|
176 |
+
"Cannot reshape non-sampled probabilisitic outputs"
|
177 |
+
)
|
178 |
+
outputs[head.name] = outputs[head.name].reshape(*head.shape)
|
179 |
+
|
180 |
+
if head.add_to_decoder_features:
|
181 |
+
outputs[f"decoder_{head.name}"] = outputs[head.name]
|
182 |
+
del outputs[head.name]
|
183 |
+
|
184 |
+
return outputs
|
spar3d/models/isosurface.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from .mesh import Mesh
|
10 |
+
|
11 |
+
|
12 |
+
class IsosurfaceHelper(nn.Module):
|
13 |
+
points_range: Tuple[float, float] = (0, 1)
|
14 |
+
|
15 |
+
@property
|
16 |
+
def grid_vertices(self) -> Float[Tensor, "N 3"]:
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
@property
|
20 |
+
def requires_instance_per_batch(self) -> bool:
|
21 |
+
return False
|
22 |
+
|
23 |
+
|
24 |
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
|
25 |
+
def __init__(self, resolution: int, tets_path: str):
|
26 |
+
super().__init__()
|
27 |
+
self.resolution = resolution
|
28 |
+
self.tets_path = tets_path
|
29 |
+
|
30 |
+
self.triangle_table: Float[Tensor, "..."]
|
31 |
+
self.register_buffer(
|
32 |
+
"triangle_table",
|
33 |
+
torch.as_tensor(
|
34 |
+
[
|
35 |
+
[-1, -1, -1, -1, -1, -1],
|
36 |
+
[1, 0, 2, -1, -1, -1],
|
37 |
+
[4, 0, 3, -1, -1, -1],
|
38 |
+
[1, 4, 2, 1, 3, 4],
|
39 |
+
[3, 1, 5, -1, -1, -1],
|
40 |
+
[2, 3, 0, 2, 5, 3],
|
41 |
+
[1, 4, 0, 1, 5, 4],
|
42 |
+
[4, 2, 5, -1, -1, -1],
|
43 |
+
[4, 5, 2, -1, -1, -1],
|
44 |
+
[4, 1, 0, 4, 5, 1],
|
45 |
+
[3, 2, 0, 3, 5, 2],
|
46 |
+
[1, 3, 5, -1, -1, -1],
|
47 |
+
[4, 1, 2, 4, 3, 1],
|
48 |
+
[3, 0, 4, -1, -1, -1],
|
49 |
+
[2, 0, 1, -1, -1, -1],
|
50 |
+
[-1, -1, -1, -1, -1, -1],
|
51 |
+
],
|
52 |
+
dtype=torch.long,
|
53 |
+
),
|
54 |
+
persistent=False,
|
55 |
+
)
|
56 |
+
self.num_triangles_table: Integer[Tensor, "..."]
|
57 |
+
self.register_buffer(
|
58 |
+
"num_triangles_table",
|
59 |
+
torch.as_tensor(
|
60 |
+
[0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
|
61 |
+
),
|
62 |
+
persistent=False,
|
63 |
+
)
|
64 |
+
self.base_tet_edges: Integer[Tensor, "..."]
|
65 |
+
self.register_buffer(
|
66 |
+
"base_tet_edges",
|
67 |
+
torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
|
68 |
+
persistent=False,
|
69 |
+
)
|
70 |
+
|
71 |
+
tets = np.load(self.tets_path)
|
72 |
+
self._grid_vertices: Float[Tensor, "..."]
|
73 |
+
self.register_buffer(
|
74 |
+
"_grid_vertices",
|
75 |
+
torch.from_numpy(tets["vertices"]).float(),
|
76 |
+
persistent=False,
|
77 |
+
)
|
78 |
+
self.indices: Integer[Tensor, "..."]
|
79 |
+
self.register_buffer(
|
80 |
+
"indices", torch.from_numpy(tets["indices"]).long(), persistent=False
|
81 |
+
)
|
82 |
+
|
83 |
+
self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
84 |
+
|
85 |
+
center_indices, boundary_indices = self.get_center_boundary_index(
|
86 |
+
self._grid_vertices
|
87 |
+
)
|
88 |
+
self.center_indices: Integer[Tensor, "..."]
|
89 |
+
self.register_buffer("center_indices", center_indices, persistent=False)
|
90 |
+
self.boundary_indices: Integer[Tensor, "..."]
|
91 |
+
self.register_buffer("boundary_indices", boundary_indices, persistent=False)
|
92 |
+
|
93 |
+
def get_center_boundary_index(self, verts):
|
94 |
+
magn = torch.sum(verts**2, dim=-1)
|
95 |
+
|
96 |
+
center_idx = torch.argmin(magn)
|
97 |
+
boundary_neg = verts == verts.max()
|
98 |
+
boundary_pos = verts == verts.min()
|
99 |
+
|
100 |
+
boundary = torch.bitwise_or(boundary_pos, boundary_neg)
|
101 |
+
boundary = torch.sum(boundary.float(), dim=-1)
|
102 |
+
|
103 |
+
boundary_idx = torch.nonzero(boundary)
|
104 |
+
return center_idx, boundary_idx.squeeze(dim=-1)
|
105 |
+
|
106 |
+
def normalize_grid_deformation(
|
107 |
+
self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
|
108 |
+
) -> Float[Tensor, "Nv 3"]:
|
109 |
+
return (
|
110 |
+
(self.points_range[1] - self.points_range[0])
|
111 |
+
/ self.resolution # half tet size is approximately 1 / self.resolution
|
112 |
+
* torch.tanh(grid_vertex_offsets)
|
113 |
+
) # FIXME: hard-coded activation
|
114 |
+
|
115 |
+
@property
|
116 |
+
def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
|
117 |
+
return self._grid_vertices
|
118 |
+
|
119 |
+
@property
|
120 |
+
def all_edges(self) -> Integer[Tensor, "Ne 2"]:
|
121 |
+
if self._all_edges is None:
|
122 |
+
# compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
|
123 |
+
edges = torch.tensor(
|
124 |
+
[0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
|
125 |
+
dtype=torch.long,
|
126 |
+
device=self.indices.device,
|
127 |
+
)
|
128 |
+
_all_edges = self.indices[:, edges].reshape(-1, 2)
|
129 |
+
_all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
|
130 |
+
_all_edges = torch.unique(_all_edges_sorted, dim=0)
|
131 |
+
self._all_edges = _all_edges
|
132 |
+
return self._all_edges
|
133 |
+
|
134 |
+
def sort_edges(self, edges_ex2):
|
135 |
+
with torch.no_grad():
|
136 |
+
order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
|
137 |
+
order = order.unsqueeze(dim=1)
|
138 |
+
|
139 |
+
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
140 |
+
b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
|
141 |
+
|
142 |
+
return torch.stack([a, b], -1)
|
143 |
+
|
144 |
+
def _forward(self, pos_nx3, sdf_n, tet_fx4):
|
145 |
+
with torch.no_grad():
|
146 |
+
occ_n = sdf_n > 0
|
147 |
+
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
|
148 |
+
occ_sum = torch.sum(occ_fx4, -1)
|
149 |
+
valid_tets = (occ_sum > 0) & (occ_sum < 4)
|
150 |
+
occ_sum = occ_sum[valid_tets]
|
151 |
+
|
152 |
+
# find all vertices
|
153 |
+
all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
|
154 |
+
all_edges = self.sort_edges(all_edges)
|
155 |
+
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
|
156 |
+
|
157 |
+
unique_edges = unique_edges.long()
|
158 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
159 |
+
mapping = (
|
160 |
+
torch.ones(
|
161 |
+
(unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
|
162 |
+
)
|
163 |
+
* -1
|
164 |
+
)
|
165 |
+
mapping[mask_edges] = torch.arange(
|
166 |
+
mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
|
167 |
+
)
|
168 |
+
idx_map = mapping[idx_map] # map edges to verts
|
169 |
+
|
170 |
+
interp_v = unique_edges[mask_edges]
|
171 |
+
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
|
172 |
+
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
|
173 |
+
edges_to_interp_sdf[:, -1] *= -1
|
174 |
+
|
175 |
+
denominator = edges_to_interp_sdf.sum(1, keepdim=True)
|
176 |
+
|
177 |
+
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
|
178 |
+
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
179 |
+
|
180 |
+
idx_map = idx_map.reshape(-1, 6)
|
181 |
+
|
182 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
|
183 |
+
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
184 |
+
num_triangles = self.num_triangles_table[tetindex]
|
185 |
+
|
186 |
+
# Generate triangle indices
|
187 |
+
faces = torch.cat(
|
188 |
+
(
|
189 |
+
torch.gather(
|
190 |
+
input=idx_map[num_triangles == 1],
|
191 |
+
dim=1,
|
192 |
+
index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
|
193 |
+
).reshape(-1, 3),
|
194 |
+
torch.gather(
|
195 |
+
input=idx_map[num_triangles == 2],
|
196 |
+
dim=1,
|
197 |
+
index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
|
198 |
+
).reshape(-1, 3),
|
199 |
+
),
|
200 |
+
dim=0,
|
201 |
+
)
|
202 |
+
|
203 |
+
return verts, faces
|
204 |
+
|
205 |
+
def forward(
|
206 |
+
self,
|
207 |
+
level: Float[Tensor, "N3 1"],
|
208 |
+
deformation: Optional[Float[Tensor, "N3 3"]] = None,
|
209 |
+
) -> Mesh:
|
210 |
+
if deformation is not None:
|
211 |
+
grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
|
212 |
+
deformation
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
grid_vertices = self.grid_vertices
|
216 |
+
|
217 |
+
v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
|
218 |
+
|
219 |
+
mesh = Mesh(
|
220 |
+
v_pos=v_pos,
|
221 |
+
t_pos_idx=t_pos_idx,
|
222 |
+
# extras
|
223 |
+
grid_vertices=grid_vertices,
|
224 |
+
tet_edges=self.all_edges,
|
225 |
+
grid_level=level,
|
226 |
+
grid_deformation=deformation,
|
227 |
+
)
|
228 |
+
|
229 |
+
return mesh
|
spar3d/models/mesh.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import trimesh
|
10 |
+
from jaxtyping import Float, Integer
|
11 |
+
from torch import Tensor
|
12 |
+
|
13 |
+
from spar3d.models.utils import dot
|
14 |
+
|
15 |
+
try:
|
16 |
+
from uv_unwrapper import Unwrapper
|
17 |
+
except ImportError:
|
18 |
+
import logging
|
19 |
+
|
20 |
+
logging.warning(
|
21 |
+
"Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
|
22 |
+
)
|
23 |
+
# Exit early to avoid further errors
|
24 |
+
raise ImportError("uv_unwrapper not found")
|
25 |
+
|
26 |
+
try:
|
27 |
+
import gpytoolbox
|
28 |
+
|
29 |
+
TRIANGLE_REMESH_AVAILABLE = True
|
30 |
+
except ImportError:
|
31 |
+
TRIANGLE_REMESH_AVAILABLE = False
|
32 |
+
import logging
|
33 |
+
|
34 |
+
logging.warning(
|
35 |
+
"Could not import gpytoolbox. Triangle remeshing functionality will be disabled. "
|
36 |
+
"Install via `pip install gpytoolbox`"
|
37 |
+
)
|
38 |
+
|
39 |
+
try:
|
40 |
+
import pynim
|
41 |
+
|
42 |
+
QUAD_REMESH_AVAILABLE = True
|
43 |
+
except ImportError:
|
44 |
+
QUAD_REMESH_AVAILABLE = False
|
45 |
+
import logging
|
46 |
+
|
47 |
+
logging.warning(
|
48 |
+
"Could not import pynim. Quad remeshing functionality will be disabled. "
|
49 |
+
"Install via `pip install git+https://github.com/vork/PyNanoInstantMeshes.git@v0.0.3`"
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
class Mesh:
|
54 |
+
def __init__(
|
55 |
+
self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
|
56 |
+
) -> None:
|
57 |
+
self.v_pos: Float[Tensor, "Nv 3"] = v_pos
|
58 |
+
self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
|
59 |
+
self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
|
60 |
+
self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
|
61 |
+
self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
|
62 |
+
self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
63 |
+
self.extras: Dict[str, Any] = {}
|
64 |
+
for k, v in kwargs.items():
|
65 |
+
self.add_extra(k, v)
|
66 |
+
|
67 |
+
self.unwrapper = Unwrapper()
|
68 |
+
|
69 |
+
def add_extra(self, k, v) -> None:
|
70 |
+
self.extras[k] = v
|
71 |
+
|
72 |
+
@property
|
73 |
+
def requires_grad(self):
|
74 |
+
return self.v_pos.requires_grad
|
75 |
+
|
76 |
+
@property
|
77 |
+
def v_nrm(self):
|
78 |
+
if self._v_nrm is None:
|
79 |
+
self._v_nrm = self._compute_vertex_normal()
|
80 |
+
return self._v_nrm
|
81 |
+
|
82 |
+
@property
|
83 |
+
def v_tng(self):
|
84 |
+
if self._v_tng is None:
|
85 |
+
self._v_tng = self._compute_vertex_tangent()
|
86 |
+
return self._v_tng
|
87 |
+
|
88 |
+
@property
|
89 |
+
def v_tex(self):
|
90 |
+
if self._v_tex is None:
|
91 |
+
self.unwrap_uv()
|
92 |
+
return self._v_tex
|
93 |
+
|
94 |
+
@property
|
95 |
+
def edges(self):
|
96 |
+
if self._edges is None:
|
97 |
+
self._edges = self._compute_edges()
|
98 |
+
return self._edges
|
99 |
+
|
100 |
+
def _compute_vertex_normal(self):
|
101 |
+
i0 = self.t_pos_idx[:, 0]
|
102 |
+
i1 = self.t_pos_idx[:, 1]
|
103 |
+
i2 = self.t_pos_idx[:, 2]
|
104 |
+
|
105 |
+
v0 = self.v_pos[i0, :]
|
106 |
+
v1 = self.v_pos[i1, :]
|
107 |
+
v2 = self.v_pos[i2, :]
|
108 |
+
|
109 |
+
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
110 |
+
|
111 |
+
# Splat face normals to vertices
|
112 |
+
v_nrm = torch.zeros_like(self.v_pos)
|
113 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
114 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
115 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
116 |
+
|
117 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
118 |
+
v_nrm = torch.where(
|
119 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
120 |
+
)
|
121 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
122 |
+
|
123 |
+
if torch.is_anomaly_enabled():
|
124 |
+
assert torch.all(torch.isfinite(v_nrm))
|
125 |
+
|
126 |
+
return v_nrm
|
127 |
+
|
128 |
+
def _compute_vertex_tangent(self):
|
129 |
+
vn_idx = [None] * 3
|
130 |
+
pos = [None] * 3
|
131 |
+
tex = [None] * 3
|
132 |
+
for i in range(0, 3):
|
133 |
+
pos[i] = self.v_pos[self.t_pos_idx[:, i]]
|
134 |
+
tex[i] = self.v_tex[self.t_pos_idx[:, i]]
|
135 |
+
# t_nrm_idx is always the same as t_pos_idx
|
136 |
+
vn_idx[i] = self.t_pos_idx[:, i]
|
137 |
+
|
138 |
+
tangents = torch.zeros_like(self.v_nrm)
|
139 |
+
tansum = torch.zeros_like(self.v_nrm)
|
140 |
+
|
141 |
+
# Compute tangent space for each triangle
|
142 |
+
duv1 = tex[1] - tex[0]
|
143 |
+
duv2 = tex[2] - tex[0]
|
144 |
+
dpos1 = pos[1] - pos[0]
|
145 |
+
dpos2 = pos[2] - pos[0]
|
146 |
+
|
147 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
148 |
+
|
149 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
150 |
+
|
151 |
+
# Avoid division by zero for degenerated texture coordinates
|
152 |
+
denom_safe = denom.clip(1e-6)
|
153 |
+
tang = tng_nom / denom_safe
|
154 |
+
|
155 |
+
# Update all 3 vertices
|
156 |
+
for i in range(0, 3):
|
157 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
158 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
159 |
+
tansum.scatter_add_(
|
160 |
+
0, idx, torch.ones_like(tang)
|
161 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
162 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
163 |
+
# triangles influence the tangent space more
|
164 |
+
tangents = tangents / tansum
|
165 |
+
|
166 |
+
# Normalize and make sure tangent is perpendicular to normal
|
167 |
+
tangents = F.normalize(tangents, dim=1)
|
168 |
+
tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
|
169 |
+
|
170 |
+
if torch.is_anomaly_enabled():
|
171 |
+
assert torch.all(torch.isfinite(tangents))
|
172 |
+
|
173 |
+
return tangents
|
174 |
+
|
175 |
+
def quad_remesh(
|
176 |
+
self,
|
177 |
+
quad_vertex_count: int = -1,
|
178 |
+
quad_rosy: int = 4,
|
179 |
+
quad_crease_angle: float = -1.0,
|
180 |
+
quad_smooth_iter: int = 2,
|
181 |
+
quad_align_to_boundaries: bool = False,
|
182 |
+
) -> Mesh:
|
183 |
+
if not QUAD_REMESH_AVAILABLE:
|
184 |
+
raise ImportError("Quad remeshing requires pynim to be installed")
|
185 |
+
if quad_vertex_count < 0:
|
186 |
+
quad_vertex_count = self.v_pos.shape[0]
|
187 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
|
188 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)
|
189 |
+
|
190 |
+
new_vert, new_faces = pynim.remesh(
|
191 |
+
v_pos,
|
192 |
+
t_pos_idx,
|
193 |
+
quad_vertex_count // 4,
|
194 |
+
rosy=quad_rosy,
|
195 |
+
posy=4,
|
196 |
+
creaseAngle=quad_crease_angle,
|
197 |
+
align_to_boundaries=quad_align_to_boundaries,
|
198 |
+
smooth_iter=quad_smooth_iter,
|
199 |
+
deterministic=False,
|
200 |
+
)
|
201 |
+
|
202 |
+
# Briefly load in trimesh
|
203 |
+
mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))
|
204 |
+
|
205 |
+
v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
|
206 |
+
t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()
|
207 |
+
|
208 |
+
# Create new mesh
|
209 |
+
return Mesh(v_pos, t_pos_idx)
|
210 |
+
|
211 |
+
def triangle_remesh(
|
212 |
+
self,
|
213 |
+
triangle_average_edge_length_multiplier: Optional[float] = None,
|
214 |
+
triangle_remesh_steps: int = 10,
|
215 |
+
triangle_vertex_count=-1,
|
216 |
+
):
|
217 |
+
if not TRIANGLE_REMESH_AVAILABLE:
|
218 |
+
raise ImportError("Triangle remeshing requires gpytoolbox to be installed")
|
219 |
+
if triangle_vertex_count > 0:
|
220 |
+
reduction = triangle_vertex_count / self.v_pos.shape[0]
|
221 |
+
print("Triangle reduction:", reduction)
|
222 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
|
223 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
|
224 |
+
if reduction > 1.0:
|
225 |
+
subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
|
226 |
+
print("Subdivide iters:", subdivide_iters)
|
227 |
+
v_pos, t_pos_idx = gpytoolbox.subdivide(
|
228 |
+
v_pos,
|
229 |
+
t_pos_idx,
|
230 |
+
iters=subdivide_iters,
|
231 |
+
)
|
232 |
+
reduction = triangle_vertex_count / v_pos.shape[0]
|
233 |
+
|
234 |
+
# Simplify
|
235 |
+
points_out, faces_out, _, _ = gpytoolbox.decimate(
|
236 |
+
v_pos,
|
237 |
+
t_pos_idx,
|
238 |
+
face_ratio=reduction,
|
239 |
+
)
|
240 |
+
|
241 |
+
# Convert back to torch
|
242 |
+
self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
|
243 |
+
self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
|
244 |
+
self._edges = None
|
245 |
+
triangle_average_edge_length_multiplier = None
|
246 |
+
|
247 |
+
edges = self.edges
|
248 |
+
if triangle_average_edge_length_multiplier is None:
|
249 |
+
h = None
|
250 |
+
else:
|
251 |
+
h = float(
|
252 |
+
torch.linalg.norm(
|
253 |
+
self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
|
254 |
+
)
|
255 |
+
.mean()
|
256 |
+
.item()
|
257 |
+
* triangle_average_edge_length_multiplier
|
258 |
+
)
|
259 |
+
|
260 |
+
# Convert to numpy
|
261 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
|
262 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
|
263 |
+
|
264 |
+
# Remesh
|
265 |
+
v_remesh, f_remesh = gpytoolbox.remesh_botsch(
|
266 |
+
v_pos,
|
267 |
+
t_pos_idx,
|
268 |
+
triangle_remesh_steps,
|
269 |
+
h,
|
270 |
+
)
|
271 |
+
|
272 |
+
# Convert back to torch
|
273 |
+
v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
|
274 |
+
t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()
|
275 |
+
|
276 |
+
# Create new mesh
|
277 |
+
return Mesh(v_pos, t_pos_idx)
|
278 |
+
|
279 |
+
@torch.no_grad()
|
280 |
+
def unwrap_uv(
|
281 |
+
self,
|
282 |
+
island_padding: float = 0.02,
|
283 |
+
) -> Mesh:
|
284 |
+
uv, indices = self.unwrapper(
|
285 |
+
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
|
286 |
+
)
|
287 |
+
|
288 |
+
# Do store per vertex UVs.
|
289 |
+
# This means we need to duplicate some vertices at the seams
|
290 |
+
individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
|
291 |
+
individual_faces = torch.arange(
|
292 |
+
individual_vertices.shape[0],
|
293 |
+
device=individual_vertices.device,
|
294 |
+
dtype=self.t_pos_idx.dtype,
|
295 |
+
).reshape(-1, 3)
|
296 |
+
uv_flat = uv[indices].reshape((-1, 2))
|
297 |
+
# uv_flat[:, 1] = 1 - uv_flat[:, 1]
|
298 |
+
|
299 |
+
self.v_pos = individual_vertices
|
300 |
+
self.t_pos_idx = individual_faces
|
301 |
+
self._v_tex = uv_flat
|
302 |
+
self._v_nrm = self._compute_vertex_normal()
|
303 |
+
self._v_tng = self._compute_vertex_tangent()
|
304 |
+
|
305 |
+
def _compute_edges(self):
|
306 |
+
# Compute edges
|
307 |
+
edges = torch.cat(
|
308 |
+
[
|
309 |
+
self.t_pos_idx[:, [0, 1]],
|
310 |
+
self.t_pos_idx[:, [1, 2]],
|
311 |
+
self.t_pos_idx[:, [2, 0]],
|
312 |
+
],
|
313 |
+
dim=0,
|
314 |
+
)
|
315 |
+
edges = edges.sort()[0]
|
316 |
+
edges = torch.unique(edges, dim=0)
|
317 |
+
return edges
|
spar3d/models/network.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from jaxtyping import Float
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.autograd import Function
|
11 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
12 |
+
|
13 |
+
from spar3d.models.utils import BaseModule, normalize
|
14 |
+
from spar3d.utils import get_device
|
15 |
+
|
16 |
+
|
17 |
+
def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
|
18 |
+
def wrapper(fn):
|
19 |
+
if condition:
|
20 |
+
if len(kwargs) == 0:
|
21 |
+
return decorator_with_args
|
22 |
+
return decorator_with_args(*args, **kwargs)(fn)
|
23 |
+
else:
|
24 |
+
return fn
|
25 |
+
|
26 |
+
return wrapper
|
27 |
+
|
28 |
+
|
29 |
+
class PixelShuffleUpsampleNetwork(BaseModule):
|
30 |
+
@dataclass
|
31 |
+
class Config(BaseModule.Config):
|
32 |
+
in_channels: int = 1024
|
33 |
+
out_channels: int = 40
|
34 |
+
scale_factor: int = 4
|
35 |
+
|
36 |
+
conv_layers: int = 4
|
37 |
+
conv_kernel_size: int = 3
|
38 |
+
|
39 |
+
cfg: Config
|
40 |
+
|
41 |
+
def configure(self) -> None:
|
42 |
+
layers = []
|
43 |
+
output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
|
44 |
+
|
45 |
+
in_channels = self.cfg.in_channels
|
46 |
+
for i in range(self.cfg.conv_layers):
|
47 |
+
cur_out_channels = (
|
48 |
+
in_channels if i != self.cfg.conv_layers - 1 else output_channels
|
49 |
+
)
|
50 |
+
layers.append(
|
51 |
+
nn.Conv2d(
|
52 |
+
in_channels,
|
53 |
+
cur_out_channels,
|
54 |
+
self.cfg.conv_kernel_size,
|
55 |
+
padding=(self.cfg.conv_kernel_size - 1) // 2,
|
56 |
+
)
|
57 |
+
)
|
58 |
+
if i != self.cfg.conv_layers - 1:
|
59 |
+
layers.append(nn.ReLU(inplace=True))
|
60 |
+
|
61 |
+
layers.append(nn.PixelShuffle(self.cfg.scale_factor))
|
62 |
+
|
63 |
+
self.upsample = nn.Sequential(*layers)
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
|
67 |
+
) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
|
68 |
+
return rearrange(
|
69 |
+
self.upsample(
|
70 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
71 |
+
),
|
72 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
73 |
+
Np=3,
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
class _TruncExp(Function): # pylint: disable=abstract-method
|
78 |
+
# Implementation from torch-ngp:
|
79 |
+
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
80 |
+
@staticmethod
|
81 |
+
@conditional_decorator(
|
82 |
+
custom_fwd, "cuda" in get_device(), cast_inputs=torch.float32
|
83 |
+
)
|
84 |
+
def forward(ctx, x): # pylint: disable=arguments-differ
|
85 |
+
ctx.save_for_backward(x)
|
86 |
+
return torch.exp(x)
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
@conditional_decorator(custom_bwd, "cuda" in get_device())
|
90 |
+
def backward(ctx, g): # pylint: disable=arguments-differ
|
91 |
+
x = ctx.saved_tensors[0]
|
92 |
+
return g * torch.exp(torch.clamp(x, max=15))
|
93 |
+
|
94 |
+
|
95 |
+
trunc_exp = _TruncExp.apply
|
96 |
+
|
97 |
+
|
98 |
+
def get_activation(name) -> Callable:
|
99 |
+
if name is None:
|
100 |
+
return lambda x: x
|
101 |
+
name = name.lower()
|
102 |
+
if name == "none" or name == "linear" or name == "identity":
|
103 |
+
return lambda x: x
|
104 |
+
elif name == "lin2srgb":
|
105 |
+
return lambda x: torch.where(
|
106 |
+
x > 0.0031308,
|
107 |
+
torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
|
108 |
+
12.92 * x,
|
109 |
+
).clamp(0.0, 1.0)
|
110 |
+
elif name == "exp":
|
111 |
+
return lambda x: torch.exp(x)
|
112 |
+
elif name == "shifted_exp":
|
113 |
+
return lambda x: torch.exp(x - 1.0)
|
114 |
+
elif name == "trunc_exp":
|
115 |
+
return trunc_exp
|
116 |
+
elif name == "shifted_trunc_exp":
|
117 |
+
return lambda x: trunc_exp(x - 1.0)
|
118 |
+
elif name == "sigmoid":
|
119 |
+
return lambda x: torch.sigmoid(x)
|
120 |
+
elif name == "tanh":
|
121 |
+
return lambda x: torch.tanh(x)
|
122 |
+
elif name == "shifted_softplus":
|
123 |
+
return lambda x: F.softplus(x - 1.0)
|
124 |
+
elif name == "scale_-11_01":
|
125 |
+
return lambda x: x * 0.5 + 0.5
|
126 |
+
elif name == "negative":
|
127 |
+
return lambda x: -x
|
128 |
+
elif name == "normalize_channel_last":
|
129 |
+
return lambda x: normalize(x)
|
130 |
+
elif name == "normalize_channel_first":
|
131 |
+
return lambda x: normalize(x, dim=1)
|
132 |
+
else:
|
133 |
+
try:
|
134 |
+
return getattr(F, name)
|
135 |
+
except AttributeError:
|
136 |
+
raise ValueError(f"Unknown activation function: {name}")
|
137 |
+
|
138 |
+
|
139 |
+
class LambdaModule(torch.nn.Module):
|
140 |
+
def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]):
|
141 |
+
super().__init__()
|
142 |
+
self.lambd = lambd
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
return self.lambd(x)
|
146 |
+
|
147 |
+
|
148 |
+
def get_activation_module(name) -> torch.nn.Module:
|
149 |
+
return LambdaModule(get_activation(name))
|
150 |
+
|
151 |
+
|
152 |
+
@dataclass
|
153 |
+
class HeadSpec:
|
154 |
+
name: str
|
155 |
+
out_channels: int
|
156 |
+
n_hidden_layers: int
|
157 |
+
output_activation: Optional[str] = None
|
158 |
+
out_bias: float = 0.0
|
159 |
+
|
160 |
+
|
161 |
+
class MaterialMLP(BaseModule):
|
162 |
+
@dataclass
|
163 |
+
class Config(BaseModule.Config):
|
164 |
+
in_channels: int = 120
|
165 |
+
n_neurons: int = 64
|
166 |
+
activation: str = "silu"
|
167 |
+
heads: List[HeadSpec] = field(default_factory=lambda: [])
|
168 |
+
|
169 |
+
cfg: Config
|
170 |
+
|
171 |
+
def configure(self) -> None:
|
172 |
+
assert len(self.cfg.heads) > 0
|
173 |
+
heads = {}
|
174 |
+
for head in self.cfg.heads:
|
175 |
+
head_layers = []
|
176 |
+
for i in range(head.n_hidden_layers):
|
177 |
+
head_layers += [
|
178 |
+
nn.Linear(
|
179 |
+
self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
|
180 |
+
self.cfg.n_neurons,
|
181 |
+
),
|
182 |
+
self.make_activation(self.cfg.activation),
|
183 |
+
]
|
184 |
+
head_layers += [
|
185 |
+
nn.Linear(
|
186 |
+
self.cfg.n_neurons,
|
187 |
+
head.out_channels,
|
188 |
+
),
|
189 |
+
]
|
190 |
+
heads[head.name] = nn.Sequential(*head_layers)
|
191 |
+
self.heads = nn.ModuleDict(heads)
|
192 |
+
|
193 |
+
def make_activation(self, activation):
|
194 |
+
if activation == "relu":
|
195 |
+
return nn.ReLU(inplace=True)
|
196 |
+
elif activation == "silu":
|
197 |
+
return nn.SiLU(inplace=True)
|
198 |
+
else:
|
199 |
+
raise NotImplementedError
|
200 |
+
|
201 |
+
def keys(self):
|
202 |
+
return self.heads.keys()
|
203 |
+
|
204 |
+
def forward(
|
205 |
+
self, x, include: Optional[List] = None, exclude: Optional[List] = None
|
206 |
+
):
|
207 |
+
if include is not None and exclude is not None:
|
208 |
+
raise ValueError("Cannot specify both include and exclude.")
|
209 |
+
if include is not None:
|
210 |
+
heads = [h for h in self.cfg.heads if h.name in include]
|
211 |
+
elif exclude is not None:
|
212 |
+
heads = [h for h in self.cfg.heads if h.name not in exclude]
|
213 |
+
else:
|
214 |
+
heads = self.cfg.heads
|
215 |
+
|
216 |
+
out = {
|
217 |
+
head.name: get_activation(head.output_activation)(
|
218 |
+
self.heads[head.name](x) + head.out_bias
|
219 |
+
)
|
220 |
+
for head in heads
|
221 |
+
}
|
222 |
+
|
223 |
+
return out
|
spar3d/models/tokenizers/dinov2.py
ADDED
@@ -0,0 +1,1196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch DINOv2 model."""
|
16 |
+
|
17 |
+
import collections.abc
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
from transformers.activations import ACT2FN
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BackboneOutput,
|
30 |
+
BaseModelOutput,
|
31 |
+
BaseModelOutputWithPooling,
|
32 |
+
ImageClassifierOutput,
|
33 |
+
)
|
34 |
+
from transformers.modeling_utils import PreTrainedModel
|
35 |
+
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
|
36 |
+
from transformers.pytorch_utils import (
|
37 |
+
find_pruneable_heads_and_indices,
|
38 |
+
prune_linear_layer,
|
39 |
+
)
|
40 |
+
from transformers.utils import (
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
replace_return_docstrings,
|
46 |
+
)
|
47 |
+
from transformers.utils.backbone_utils import BackboneMixin
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
# General docstring
|
52 |
+
_CONFIG_FOR_DOC = "Dinov2Config"
|
53 |
+
|
54 |
+
# Base docstring
|
55 |
+
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
56 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
57 |
+
|
58 |
+
# Image classification docstring
|
59 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
60 |
+
|
61 |
+
|
62 |
+
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
63 |
+
"facebook/dinov2-base",
|
64 |
+
# See all DINOv2 models at https://huggingface.co/models?filter=dinov2
|
65 |
+
]
|
66 |
+
|
67 |
+
|
68 |
+
class Dinov2Embeddings(nn.Module):
|
69 |
+
"""
|
70 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, config: Dinov2Config) -> None:
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
77 |
+
# register as mask token as it's not used in optimization
|
78 |
+
# to avoid the use of find_unused_parameters_true
|
79 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
|
80 |
+
self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
|
81 |
+
self.patch_embeddings = Dinov2PatchEmbeddings(config)
|
82 |
+
num_patches = self.patch_embeddings.num_patches
|
83 |
+
self.position_embeddings = nn.Parameter(
|
84 |
+
torch.randn(1, num_patches + 1, config.hidden_size)
|
85 |
+
)
|
86 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
87 |
+
self.config = config
|
88 |
+
|
89 |
+
def interpolate_pos_encoding(
|
90 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
91 |
+
) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
94 |
+
resolution images.
|
95 |
+
|
96 |
+
Source:
|
97 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
98 |
+
"""
|
99 |
+
|
100 |
+
num_patches = embeddings.shape[1] - 1
|
101 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
102 |
+
if num_patches == num_positions and height == width:
|
103 |
+
return self.position_embeddings
|
104 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
105 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
106 |
+
dim = embeddings.shape[-1]
|
107 |
+
height = height // self.config.patch_size
|
108 |
+
width = width // self.config.patch_size
|
109 |
+
# we add a small number to avoid floating point error in the interpolation
|
110 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
111 |
+
height, width = height + 0.1, width + 0.1
|
112 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
113 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
114 |
+
)
|
115 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
116 |
+
patch_pos_embed = nn.functional.interpolate(
|
117 |
+
patch_pos_embed,
|
118 |
+
scale_factor=(
|
119 |
+
height / math.sqrt(num_positions),
|
120 |
+
width / math.sqrt(num_positions),
|
121 |
+
),
|
122 |
+
mode="bicubic",
|
123 |
+
align_corners=False,
|
124 |
+
)
|
125 |
+
if (
|
126 |
+
int(height) != patch_pos_embed.shape[-2]
|
127 |
+
or int(width) != patch_pos_embed.shape[-1]
|
128 |
+
):
|
129 |
+
raise ValueError(
|
130 |
+
"Width or height does not match with the interpolated position embeddings"
|
131 |
+
)
|
132 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
133 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
pixel_values: torch.Tensor,
|
138 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
139 |
+
) -> torch.Tensor:
|
140 |
+
batch_size, _, height, width = pixel_values.shape
|
141 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
142 |
+
embeddings = patch_embeddings
|
143 |
+
|
144 |
+
if bool_masked_pos is not None:
|
145 |
+
embeddings = torch.where(
|
146 |
+
bool_masked_pos.unsqueeze(-1),
|
147 |
+
self.mask_token.to(embeddings.dtype).unsqueeze(0),
|
148 |
+
embeddings,
|
149 |
+
)
|
150 |
+
|
151 |
+
# add the [CLS] token to the embedded patch tokens
|
152 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
153 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
154 |
+
|
155 |
+
# add positional encoding to each token
|
156 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
157 |
+
embeddings, height, width
|
158 |
+
)
|
159 |
+
|
160 |
+
embeddings = self.dropout(embeddings)
|
161 |
+
|
162 |
+
return embeddings
|
163 |
+
|
164 |
+
|
165 |
+
class Dinov2PatchEmbeddings(nn.Module):
|
166 |
+
"""
|
167 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
168 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
169 |
+
Transformer.
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(self, config):
|
173 |
+
super().__init__()
|
174 |
+
image_size, patch_size = config.image_size, config.patch_size
|
175 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
176 |
+
|
177 |
+
image_size = (
|
178 |
+
image_size
|
179 |
+
if isinstance(image_size, collections.abc.Iterable)
|
180 |
+
else (image_size, image_size)
|
181 |
+
)
|
182 |
+
patch_size = (
|
183 |
+
patch_size
|
184 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
185 |
+
else (patch_size, patch_size)
|
186 |
+
)
|
187 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
188 |
+
image_size[0] // patch_size[0]
|
189 |
+
)
|
190 |
+
self.image_size = image_size
|
191 |
+
self.patch_size = patch_size
|
192 |
+
self.num_channels = num_channels
|
193 |
+
self.num_patches = num_patches
|
194 |
+
|
195 |
+
self.projection = nn.Conv2d(
|
196 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
197 |
+
)
|
198 |
+
|
199 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
200 |
+
"""
|
201 |
+
num_channels = pixel_values.shape[1]
|
202 |
+
if num_channels != self.num_channels:
|
203 |
+
raise ValueError(
|
204 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
205 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
206 |
+
)
|
207 |
+
"""
|
208 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
209 |
+
return embeddings
|
210 |
+
|
211 |
+
|
212 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
213 |
+
class Dinov2SelfAttention(nn.Module):
|
214 |
+
def __init__(self, config: Dinov2Config) -> None:
|
215 |
+
super().__init__()
|
216 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
217 |
+
config, "embedding_size"
|
218 |
+
):
|
219 |
+
raise ValueError(
|
220 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
221 |
+
f"heads {config.num_attention_heads}."
|
222 |
+
)
|
223 |
+
|
224 |
+
self.num_attention_heads = config.num_attention_heads
|
225 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
226 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
227 |
+
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
228 |
+
|
229 |
+
self.query = nn.Linear(
|
230 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
231 |
+
)
|
232 |
+
self.key = nn.Linear(
|
233 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
234 |
+
)
|
235 |
+
self.value = nn.Linear(
|
236 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
237 |
+
)
|
238 |
+
|
239 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
240 |
+
|
241 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
242 |
+
new_x_shape = x.size()[:-1] + (
|
243 |
+
self.num_attention_heads,
|
244 |
+
self.attention_head_size,
|
245 |
+
)
|
246 |
+
x = x.view(new_x_shape)
|
247 |
+
return x.permute(0, 2, 1, 3)
|
248 |
+
|
249 |
+
def forward(
|
250 |
+
self,
|
251 |
+
hidden_states,
|
252 |
+
head_mask: Optional[torch.Tensor] = None,
|
253 |
+
output_attentions: bool = False,
|
254 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
255 |
+
mixed_query_layer = self.query(hidden_states)
|
256 |
+
|
257 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
258 |
+
assert head_mask is None and not output_attentions
|
259 |
+
new_size = hidden_states.size()[:-1] + (
|
260 |
+
self.num_attention_heads,
|
261 |
+
self.attention_head_size,
|
262 |
+
)
|
263 |
+
key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
|
264 |
+
value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
|
265 |
+
query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
|
266 |
+
context_layer = F.scaled_dot_product_attention(
|
267 |
+
query_layer,
|
268 |
+
key_layer,
|
269 |
+
value_layer,
|
270 |
+
dropout_p=self.attention_probs_dropout_prob,
|
271 |
+
is_causal=False,
|
272 |
+
)
|
273 |
+
context_layer = context_layer.transpose(1, 2).reshape(
|
274 |
+
*hidden_states.size()[:-1], -1
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
278 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
279 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
280 |
+
|
281 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
282 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
283 |
+
|
284 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
285 |
+
|
286 |
+
# Normalize the attention scores to probabilities.
|
287 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
288 |
+
|
289 |
+
# This is actually dropping out entire tokens to attend to, which might
|
290 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
291 |
+
attention_probs = self.dropout(attention_probs)
|
292 |
+
|
293 |
+
# Mask heads if we want to
|
294 |
+
if head_mask is not None:
|
295 |
+
attention_probs = attention_probs * head_mask
|
296 |
+
|
297 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
298 |
+
|
299 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
300 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
301 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
302 |
+
|
303 |
+
outputs = (
|
304 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
305 |
+
)
|
306 |
+
|
307 |
+
return outputs
|
308 |
+
|
309 |
+
|
310 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
311 |
+
class Dinov2SelfOutput(nn.Module):
|
312 |
+
"""
|
313 |
+
The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
|
314 |
+
layernorm applied before each block.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, config: Dinov2Config) -> None:
|
318 |
+
super().__init__()
|
319 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
320 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
321 |
+
|
322 |
+
def forward(
|
323 |
+
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
|
324 |
+
) -> torch.Tensor:
|
325 |
+
hidden_states = self.dense(hidden_states)
|
326 |
+
hidden_states = self.dropout(hidden_states)
|
327 |
+
|
328 |
+
return hidden_states
|
329 |
+
|
330 |
+
|
331 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
|
332 |
+
class Dinov2Attention(nn.Module):
|
333 |
+
def __init__(self, config: Dinov2Config) -> None:
|
334 |
+
super().__init__()
|
335 |
+
self.attention = Dinov2SelfAttention(config)
|
336 |
+
self.output = Dinov2SelfOutput(config)
|
337 |
+
self.pruned_heads = set()
|
338 |
+
|
339 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
340 |
+
if len(heads) == 0:
|
341 |
+
return
|
342 |
+
heads, index = find_pruneable_heads_and_indices(
|
343 |
+
heads,
|
344 |
+
self.attention.num_attention_heads,
|
345 |
+
self.attention.attention_head_size,
|
346 |
+
self.pruned_heads,
|
347 |
+
)
|
348 |
+
|
349 |
+
# Prune linear layers
|
350 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
351 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
352 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
353 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
354 |
+
|
355 |
+
# Update hyper params and store pruned heads
|
356 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(
|
357 |
+
heads
|
358 |
+
)
|
359 |
+
self.attention.all_head_size = (
|
360 |
+
self.attention.attention_head_size * self.attention.num_attention_heads
|
361 |
+
)
|
362 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
head_mask: Optional[torch.Tensor] = None,
|
368 |
+
output_attentions: bool = False,
|
369 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
370 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
371 |
+
|
372 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
373 |
+
|
374 |
+
outputs = (attention_output,) + self_outputs[
|
375 |
+
1:
|
376 |
+
] # add attentions if we output them
|
377 |
+
return outputs
|
378 |
+
|
379 |
+
|
380 |
+
class Dinov2LayerScale(nn.Module):
|
381 |
+
def __init__(self, config) -> None:
|
382 |
+
super().__init__()
|
383 |
+
self.lambda1 = nn.Parameter(
|
384 |
+
config.layerscale_value * torch.ones(config.hidden_size)
|
385 |
+
)
|
386 |
+
|
387 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
388 |
+
return hidden_state * self.lambda1
|
389 |
+
|
390 |
+
|
391 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
392 |
+
def drop_path(
|
393 |
+
input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
|
394 |
+
) -> torch.Tensor:
|
395 |
+
"""
|
396 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
397 |
+
|
398 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
399 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
400 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
401 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
402 |
+
argument.
|
403 |
+
"""
|
404 |
+
if drop_prob == 0.0 or not training:
|
405 |
+
return input
|
406 |
+
keep_prob = 1 - drop_prob
|
407 |
+
shape = (input.shape[0],) + (1,) * (
|
408 |
+
input.ndim - 1
|
409 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
410 |
+
random_tensor = keep_prob + torch.rand(
|
411 |
+
shape, dtype=input.dtype, device=input.device
|
412 |
+
)
|
413 |
+
random_tensor.floor_() # binarize
|
414 |
+
output = input.div(keep_prob) * random_tensor
|
415 |
+
return output
|
416 |
+
|
417 |
+
|
418 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
|
419 |
+
class Dinov2DropPath(nn.Module):
|
420 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
421 |
+
|
422 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
423 |
+
super().__init__()
|
424 |
+
self.drop_prob = drop_prob
|
425 |
+
|
426 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
427 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
428 |
+
|
429 |
+
def extra_repr(self) -> str:
|
430 |
+
return "p={}".format(self.drop_prob)
|
431 |
+
|
432 |
+
|
433 |
+
class Dinov2MLP(nn.Module):
|
434 |
+
def __init__(self, config) -> None:
|
435 |
+
super().__init__()
|
436 |
+
in_features = out_features = config.hidden_size
|
437 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
438 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
|
439 |
+
if isinstance(config.hidden_act, str):
|
440 |
+
self.activation = ACT2FN[config.hidden_act]
|
441 |
+
else:
|
442 |
+
self.activation = config.hidden_act
|
443 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
|
444 |
+
|
445 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
446 |
+
hidden_state = self.fc1(hidden_state)
|
447 |
+
hidden_state = self.activation(hidden_state)
|
448 |
+
hidden_state = self.fc2(hidden_state)
|
449 |
+
return hidden_state
|
450 |
+
|
451 |
+
|
452 |
+
class Dinov2SwiGLUFFN(nn.Module):
|
453 |
+
def __init__(self, config) -> None:
|
454 |
+
super().__init__()
|
455 |
+
in_features = out_features = config.hidden_size
|
456 |
+
hidden_features = int(config.hidden_size * config.mlp_ratio)
|
457 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
458 |
+
|
459 |
+
self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
|
460 |
+
self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
|
461 |
+
|
462 |
+
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
463 |
+
hidden_state = self.weights_in(hidden_state)
|
464 |
+
x1, x2 = hidden_state.chunk(2, dim=-1)
|
465 |
+
hidden = nn.functional.silu(x1) * x2
|
466 |
+
return self.weights_out(hidden)
|
467 |
+
|
468 |
+
|
469 |
+
class Dinov2Layer(nn.Module):
|
470 |
+
"""This corresponds to the Block class in the original implementation."""
|
471 |
+
|
472 |
+
def __init__(self, config: Dinov2Config) -> None:
|
473 |
+
super().__init__()
|
474 |
+
|
475 |
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
476 |
+
self.norm1_modulation = None
|
477 |
+
self.attention = Dinov2Attention(config)
|
478 |
+
self.layer_scale1 = Dinov2LayerScale(config)
|
479 |
+
self.drop_path1 = (
|
480 |
+
Dinov2DropPath(config.drop_path_rate)
|
481 |
+
if config.drop_path_rate > 0.0
|
482 |
+
else nn.Identity()
|
483 |
+
)
|
484 |
+
|
485 |
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
486 |
+
self.norm2_modulation = None
|
487 |
+
|
488 |
+
if config.use_swiglu_ffn:
|
489 |
+
self.mlp = Dinov2SwiGLUFFN(config)
|
490 |
+
else:
|
491 |
+
self.mlp = Dinov2MLP(config)
|
492 |
+
self.layer_scale2 = Dinov2LayerScale(config)
|
493 |
+
self.drop_path2 = (
|
494 |
+
Dinov2DropPath(config.drop_path_rate)
|
495 |
+
if config.drop_path_rate > 0.0
|
496 |
+
else nn.Identity()
|
497 |
+
)
|
498 |
+
|
499 |
+
def forward(
|
500 |
+
self,
|
501 |
+
hidden_states: torch.Tensor,
|
502 |
+
head_mask: Optional[torch.Tensor] = None,
|
503 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
504 |
+
output_attentions: bool = False,
|
505 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
506 |
+
hidden_states_norm = self.norm1(hidden_states)
|
507 |
+
if self.norm1_modulation is not None:
|
508 |
+
assert modulation_cond is not None
|
509 |
+
hidden_states_norm = self.norm1_modulation(
|
510 |
+
hidden_states_norm, modulation_cond
|
511 |
+
)
|
512 |
+
self_attention_outputs = self.attention(
|
513 |
+
hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
|
514 |
+
head_mask,
|
515 |
+
output_attentions=output_attentions,
|
516 |
+
)
|
517 |
+
attention_output = self_attention_outputs[0]
|
518 |
+
|
519 |
+
attention_output = self.layer_scale1(attention_output)
|
520 |
+
outputs = self_attention_outputs[
|
521 |
+
1:
|
522 |
+
] # add self attentions if we output attention weights
|
523 |
+
|
524 |
+
# first residual connection
|
525 |
+
hidden_states = attention_output + hidden_states
|
526 |
+
|
527 |
+
# in Dinov2, layernorm is also applied after self-attention
|
528 |
+
layer_output = self.norm2(hidden_states)
|
529 |
+
if self.norm2_modulation is not None:
|
530 |
+
assert modulation_cond is not None
|
531 |
+
layer_output = self.norm2_modulation(layer_output, modulation_cond)
|
532 |
+
layer_output = self.mlp(layer_output)
|
533 |
+
layer_output = self.layer_scale2(layer_output)
|
534 |
+
|
535 |
+
# second residual connection
|
536 |
+
layer_output = layer_output + hidden_states
|
537 |
+
|
538 |
+
outputs = (layer_output,) + outputs
|
539 |
+
|
540 |
+
return outputs
|
541 |
+
|
542 |
+
def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
|
543 |
+
self.norm1_modulation = norm1_mod
|
544 |
+
self.norm2_modulation = norm2_mod
|
545 |
+
|
546 |
+
|
547 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
|
548 |
+
class Dinov2Encoder(nn.Module):
|
549 |
+
def __init__(self, config: Dinov2Config) -> None:
|
550 |
+
super().__init__()
|
551 |
+
self.config = config
|
552 |
+
self.layer = nn.ModuleList(
|
553 |
+
[Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
|
554 |
+
)
|
555 |
+
self.gradient_checkpointing = False
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self,
|
559 |
+
hidden_states: torch.Tensor,
|
560 |
+
head_mask: Optional[torch.Tensor] = None,
|
561 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
562 |
+
output_attentions: bool = False,
|
563 |
+
output_hidden_states: bool = False,
|
564 |
+
return_dict: bool = True,
|
565 |
+
) -> Union[tuple, BaseModelOutput]:
|
566 |
+
all_hidden_states = () if output_hidden_states else None
|
567 |
+
all_self_attentions = () if output_attentions else None
|
568 |
+
|
569 |
+
for i, layer_module in enumerate(self.layer):
|
570 |
+
if output_hidden_states:
|
571 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
572 |
+
|
573 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
574 |
+
|
575 |
+
if self.gradient_checkpointing and self.training:
|
576 |
+
|
577 |
+
def create_custom_forward(module):
|
578 |
+
def custom_forward(*inputs):
|
579 |
+
return module(*inputs, output_attentions)
|
580 |
+
|
581 |
+
return custom_forward
|
582 |
+
|
583 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
584 |
+
create_custom_forward(layer_module),
|
585 |
+
hidden_states,
|
586 |
+
layer_head_mask,
|
587 |
+
modulation_cond,
|
588 |
+
use_reentrant=False,
|
589 |
+
)
|
590 |
+
else:
|
591 |
+
layer_outputs = layer_module(
|
592 |
+
hidden_states, layer_head_mask, modulation_cond, output_attentions
|
593 |
+
)
|
594 |
+
|
595 |
+
hidden_states = layer_outputs[0]
|
596 |
+
|
597 |
+
if output_attentions:
|
598 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
599 |
+
|
600 |
+
if output_hidden_states:
|
601 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
602 |
+
|
603 |
+
if not return_dict:
|
604 |
+
return tuple(
|
605 |
+
v
|
606 |
+
for v in [hidden_states, all_hidden_states, all_self_attentions]
|
607 |
+
if v is not None
|
608 |
+
)
|
609 |
+
return BaseModelOutput(
|
610 |
+
last_hidden_state=hidden_states,
|
611 |
+
hidden_states=all_hidden_states,
|
612 |
+
attentions=all_self_attentions,
|
613 |
+
)
|
614 |
+
|
615 |
+
|
616 |
+
class Dinov2PreTrainedModel(PreTrainedModel):
|
617 |
+
"""
|
618 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
619 |
+
models.
|
620 |
+
"""
|
621 |
+
|
622 |
+
config_class = Dinov2Config
|
623 |
+
base_model_prefix = "dinov2"
|
624 |
+
main_input_name = "pixel_values"
|
625 |
+
supports_gradient_checkpointing = True
|
626 |
+
|
627 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
628 |
+
"""Initialize the weights"""
|
629 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
630 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
631 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
632 |
+
module.weight.data = nn.init.trunc_normal_(
|
633 |
+
module.weight.data.to(torch.float32),
|
634 |
+
mean=0.0,
|
635 |
+
std=self.config.initializer_range,
|
636 |
+
).to(module.weight.dtype)
|
637 |
+
if module.bias is not None:
|
638 |
+
module.bias.data.zero_()
|
639 |
+
elif isinstance(module, nn.LayerNorm):
|
640 |
+
module.bias.data.zero_()
|
641 |
+
module.weight.data.fill_(1.0)
|
642 |
+
elif isinstance(module, Dinov2Embeddings):
|
643 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
644 |
+
module.position_embeddings.data.to(torch.float32),
|
645 |
+
mean=0.0,
|
646 |
+
std=self.config.initializer_range,
|
647 |
+
).to(module.position_embeddings.dtype)
|
648 |
+
|
649 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
650 |
+
module.cls_token.data.to(torch.float32),
|
651 |
+
mean=0.0,
|
652 |
+
std=self.config.initializer_range,
|
653 |
+
).to(module.cls_token.dtype)
|
654 |
+
|
655 |
+
def _set_gradient_checkpointing(
|
656 |
+
self, module: Dinov2Encoder, value: bool = False
|
657 |
+
) -> None:
|
658 |
+
if isinstance(module, Dinov2Encoder):
|
659 |
+
module.gradient_checkpointing = value
|
660 |
+
|
661 |
+
|
662 |
+
DINOV2_START_DOCSTRING = r"""
|
663 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
664 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
665 |
+
behavior.
|
666 |
+
|
667 |
+
Parameters:
|
668 |
+
config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
|
669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
671 |
+
"""
|
672 |
+
|
673 |
+
DINOV2_BASE_INPUTS_DOCSTRING = r"""
|
674 |
+
Args:
|
675 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
676 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
677 |
+
[`BitImageProcessor.preprocess`] for details.
|
678 |
+
|
679 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
|
680 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
|
681 |
+
pre-training.
|
682 |
+
|
683 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
684 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
685 |
+
|
686 |
+
- 1 indicates the head is **not masked**,
|
687 |
+
- 0 indicates the head is **masked**.
|
688 |
+
|
689 |
+
output_attentions (`bool`, *optional*):
|
690 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
691 |
+
tensors for more detail.
|
692 |
+
output_hidden_states (`bool`, *optional*):
|
693 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
694 |
+
more detail.
|
695 |
+
return_dict (`bool`, *optional*):
|
696 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
697 |
+
"""
|
698 |
+
|
699 |
+
DINOV2_INPUTS_DOCSTRING = r"""
|
700 |
+
Args:
|
701 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
702 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
703 |
+
[`BitImageProcessor.preprocess`] for details.
|
704 |
+
|
705 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
706 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
707 |
+
|
708 |
+
- 1 indicates the head is **not masked**,
|
709 |
+
- 0 indicates the head is **masked**.
|
710 |
+
|
711 |
+
output_attentions (`bool`, *optional*):
|
712 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
713 |
+
tensors for more detail.
|
714 |
+
output_hidden_states (`bool`, *optional*):
|
715 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
716 |
+
more detail.
|
717 |
+
return_dict (`bool`, *optional*):
|
718 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
719 |
+
"""
|
720 |
+
|
721 |
+
|
722 |
+
@dataclass
|
723 |
+
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
|
724 |
+
patch_embeddings: Optional[torch.FloatTensor] = None
|
725 |
+
|
726 |
+
|
727 |
+
@add_start_docstrings(
|
728 |
+
"The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
|
729 |
+
DINOV2_START_DOCSTRING,
|
730 |
+
)
|
731 |
+
class Dinov2Model(Dinov2PreTrainedModel):
|
732 |
+
def __init__(self, config: Dinov2Config):
|
733 |
+
super().__init__(config)
|
734 |
+
self.config = config
|
735 |
+
|
736 |
+
self.embeddings = Dinov2Embeddings(config)
|
737 |
+
self.encoder = Dinov2Encoder(config)
|
738 |
+
|
739 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
740 |
+
|
741 |
+
# Initialize weights and apply final processing
|
742 |
+
self.post_init()
|
743 |
+
|
744 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
745 |
+
return self.embeddings.patch_embeddings
|
746 |
+
|
747 |
+
def expand_input_channels(self, extra_input_channels: int) -> None:
|
748 |
+
if extra_input_channels == 0:
|
749 |
+
return
|
750 |
+
conv_old = self.embeddings.patch_embeddings.projection
|
751 |
+
conv_new = nn.Conv2d(
|
752 |
+
self.config.num_channels + extra_input_channels,
|
753 |
+
self.config.hidden_size,
|
754 |
+
kernel_size=self.config.patch_size,
|
755 |
+
stride=self.config.patch_size,
|
756 |
+
).to(self.device)
|
757 |
+
with torch.no_grad():
|
758 |
+
conv_new.weight[:, :3] = conv_old.weight
|
759 |
+
conv_new.bias = conv_old.bias
|
760 |
+
self.embeddings.patch_embeddings.projection = conv_new
|
761 |
+
del conv_old
|
762 |
+
|
763 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
764 |
+
"""
|
765 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
766 |
+
class PreTrainedModel
|
767 |
+
"""
|
768 |
+
for layer, heads in heads_to_prune.items():
|
769 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
770 |
+
|
771 |
+
@add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
|
772 |
+
@add_code_sample_docstrings(
|
773 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
774 |
+
output_type=BaseModelOutputWithPooling,
|
775 |
+
config_class=_CONFIG_FOR_DOC,
|
776 |
+
modality="vision",
|
777 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
778 |
+
)
|
779 |
+
def forward(
|
780 |
+
self,
|
781 |
+
pixel_values: Optional[torch.Tensor] = None,
|
782 |
+
bool_masked_pos: Optional[torch.Tensor] = None,
|
783 |
+
head_mask: Optional[torch.Tensor] = None,
|
784 |
+
modulation_cond: Optional[torch.Tensor] = None,
|
785 |
+
output_attentions: Optional[bool] = None,
|
786 |
+
output_hidden_states: Optional[bool] = None,
|
787 |
+
return_dict: Optional[bool] = None,
|
788 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
789 |
+
output_attentions = (
|
790 |
+
output_attentions
|
791 |
+
if output_attentions is not None
|
792 |
+
else self.config.output_attentions
|
793 |
+
)
|
794 |
+
output_hidden_states = (
|
795 |
+
output_hidden_states
|
796 |
+
if output_hidden_states is not None
|
797 |
+
else self.config.output_hidden_states
|
798 |
+
)
|
799 |
+
return_dict = (
|
800 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
801 |
+
)
|
802 |
+
|
803 |
+
if pixel_values is None:
|
804 |
+
raise ValueError("You have to specify pixel_values")
|
805 |
+
|
806 |
+
# Prepare head mask if needed
|
807 |
+
# 1.0 in head_mask indicate we keep the head
|
808 |
+
# attention_probs has shape bsz x n_heads x N x N
|
809 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
810 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
811 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
812 |
+
|
813 |
+
embedding_output = self.embeddings(
|
814 |
+
pixel_values, bool_masked_pos=bool_masked_pos
|
815 |
+
)
|
816 |
+
|
817 |
+
encoder_outputs = self.encoder(
|
818 |
+
embedding_output,
|
819 |
+
head_mask=head_mask,
|
820 |
+
modulation_cond=modulation_cond,
|
821 |
+
output_attentions=output_attentions,
|
822 |
+
output_hidden_states=output_hidden_states,
|
823 |
+
return_dict=return_dict,
|
824 |
+
)
|
825 |
+
sequence_output = encoder_outputs[0]
|
826 |
+
sequence_output = self.layernorm(sequence_output)
|
827 |
+
pooled_output = sequence_output[:, 0, :]
|
828 |
+
|
829 |
+
if not return_dict:
|
830 |
+
head_outputs = (sequence_output, pooled_output)
|
831 |
+
return head_outputs + encoder_outputs[1:]
|
832 |
+
|
833 |
+
return CustomBaseModelOutputWithPooling(
|
834 |
+
last_hidden_state=sequence_output,
|
835 |
+
pooler_output=pooled_output,
|
836 |
+
hidden_states=encoder_outputs.hidden_states,
|
837 |
+
attentions=encoder_outputs.attentions,
|
838 |
+
patch_embeddings=embedding_output,
|
839 |
+
)
|
840 |
+
|
841 |
+
def set_gradient_checkpointing(self, value: bool = False) -> None:
|
842 |
+
self._set_gradient_checkpointing(self.encoder, value)
|
843 |
+
|
844 |
+
|
845 |
+
@add_start_docstrings(
|
846 |
+
"""
|
847 |
+
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
|
848 |
+
of the [CLS] token) e.g. for ImageNet.
|
849 |
+
""",
|
850 |
+
DINOV2_START_DOCSTRING,
|
851 |
+
)
|
852 |
+
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
|
853 |
+
def __init__(self, config: Dinov2Config) -> None:
|
854 |
+
super().__init__(config)
|
855 |
+
|
856 |
+
self.num_labels = config.num_labels
|
857 |
+
self.dinov2 = Dinov2Model(config)
|
858 |
+
|
859 |
+
# Classifier head
|
860 |
+
self.classifier = (
|
861 |
+
nn.Linear(config.hidden_size * 2, config.num_labels)
|
862 |
+
if config.num_labels > 0
|
863 |
+
else nn.Identity()
|
864 |
+
)
|
865 |
+
|
866 |
+
# Initialize weights and apply final processing
|
867 |
+
self.post_init()
|
868 |
+
|
869 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
870 |
+
@add_code_sample_docstrings(
|
871 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
872 |
+
output_type=ImageClassifierOutput,
|
873 |
+
config_class=_CONFIG_FOR_DOC,
|
874 |
+
)
|
875 |
+
def forward(
|
876 |
+
self,
|
877 |
+
pixel_values: Optional[torch.Tensor] = None,
|
878 |
+
head_mask: Optional[torch.Tensor] = None,
|
879 |
+
labels: Optional[torch.Tensor] = None,
|
880 |
+
output_attentions: Optional[bool] = None,
|
881 |
+
output_hidden_states: Optional[bool] = None,
|
882 |
+
return_dict: Optional[bool] = None,
|
883 |
+
) -> Union[tuple, ImageClassifierOutput]:
|
884 |
+
r"""
|
885 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
886 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
887 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
888 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
889 |
+
"""
|
890 |
+
return_dict = (
|
891 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
892 |
+
)
|
893 |
+
|
894 |
+
outputs = self.dinov2(
|
895 |
+
pixel_values,
|
896 |
+
head_mask=head_mask,
|
897 |
+
output_attentions=output_attentions,
|
898 |
+
output_hidden_states=output_hidden_states,
|
899 |
+
return_dict=return_dict,
|
900 |
+
)
|
901 |
+
|
902 |
+
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
|
903 |
+
|
904 |
+
cls_token = sequence_output[:, 0]
|
905 |
+
patch_tokens = sequence_output[:, 1:]
|
906 |
+
|
907 |
+
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
|
908 |
+
|
909 |
+
logits = self.classifier(linear_input)
|
910 |
+
|
911 |
+
loss = None
|
912 |
+
if labels is not None:
|
913 |
+
# move labels to correct device to enable model parallelism
|
914 |
+
labels = labels.to(logits.device)
|
915 |
+
if self.config.problem_type is None:
|
916 |
+
if self.num_labels == 1:
|
917 |
+
self.config.problem_type = "regression"
|
918 |
+
elif self.num_labels > 1 and (
|
919 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
920 |
+
):
|
921 |
+
self.config.problem_type = "single_label_classification"
|
922 |
+
else:
|
923 |
+
self.config.problem_type = "multi_label_classification"
|
924 |
+
|
925 |
+
if self.config.problem_type == "regression":
|
926 |
+
loss_fct = MSELoss()
|
927 |
+
if self.num_labels == 1:
|
928 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
929 |
+
else:
|
930 |
+
loss = loss_fct(logits, labels)
|
931 |
+
elif self.config.problem_type == "single_label_classification":
|
932 |
+
loss_fct = CrossEntropyLoss()
|
933 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
934 |
+
elif self.config.problem_type == "multi_label_classification":
|
935 |
+
loss_fct = BCEWithLogitsLoss()
|
936 |
+
loss = loss_fct(logits, labels)
|
937 |
+
|
938 |
+
if not return_dict:
|
939 |
+
output = (logits,) + outputs[2:]
|
940 |
+
return ((loss,) + output) if loss is not None else output
|
941 |
+
|
942 |
+
return ImageClassifierOutput(
|
943 |
+
loss=loss,
|
944 |
+
logits=logits,
|
945 |
+
hidden_states=outputs.hidden_states,
|
946 |
+
attentions=outputs.attentions,
|
947 |
+
)
|
948 |
+
|
949 |
+
|
950 |
+
@add_start_docstrings(
|
951 |
+
"""
|
952 |
+
Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
|
953 |
+
""",
|
954 |
+
DINOV2_START_DOCSTRING,
|
955 |
+
)
|
956 |
+
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
|
957 |
+
def __init__(self, config):
|
958 |
+
super().__init__(config)
|
959 |
+
super()._init_backbone(config)
|
960 |
+
|
961 |
+
self.num_features = [
|
962 |
+
config.hidden_size for _ in range(config.num_hidden_layers + 1)
|
963 |
+
]
|
964 |
+
self.embeddings = Dinov2Embeddings(config)
|
965 |
+
self.encoder = Dinov2Encoder(config)
|
966 |
+
|
967 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
968 |
+
|
969 |
+
# Initialize weights and apply final processing
|
970 |
+
self.post_init()
|
971 |
+
|
972 |
+
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
|
973 |
+
return self.embeddings.patch_embeddings
|
974 |
+
|
975 |
+
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
|
976 |
+
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
977 |
+
def forward(
|
978 |
+
self,
|
979 |
+
pixel_values: torch.Tensor,
|
980 |
+
output_hidden_states: Optional[bool] = None,
|
981 |
+
output_attentions: Optional[bool] = None,
|
982 |
+
return_dict: Optional[bool] = None,
|
983 |
+
) -> BackboneOutput:
|
984 |
+
"""
|
985 |
+
Returns:
|
986 |
+
|
987 |
+
Examples:
|
988 |
+
|
989 |
+
```python
|
990 |
+
>>> from transformers import AutoImageProcessor, AutoBackbone
|
991 |
+
>>> import torch
|
992 |
+
>>> from PIL import Image
|
993 |
+
>>> import requests
|
994 |
+
|
995 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
996 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
997 |
+
|
998 |
+
>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
999 |
+
>>> model = AutoBackbone.from_pretrained(
|
1000 |
+
... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
|
1001 |
+
... )
|
1002 |
+
|
1003 |
+
>>> inputs = processor(image, return_tensors="pt")
|
1004 |
+
|
1005 |
+
>>> outputs = model(**inputs)
|
1006 |
+
>>> feature_maps = outputs.feature_maps
|
1007 |
+
>>> list(feature_maps[-1].shape)
|
1008 |
+
[1, 768, 16, 16]
|
1009 |
+
```"""
|
1010 |
+
return_dict = (
|
1011 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1012 |
+
)
|
1013 |
+
output_hidden_states = (
|
1014 |
+
output_hidden_states
|
1015 |
+
if output_hidden_states is not None
|
1016 |
+
else self.config.output_hidden_states
|
1017 |
+
)
|
1018 |
+
output_attentions = (
|
1019 |
+
output_attentions
|
1020 |
+
if output_attentions is not None
|
1021 |
+
else self.config.output_attentions
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
embedding_output = self.embeddings(pixel_values)
|
1025 |
+
|
1026 |
+
outputs = self.encoder(
|
1027 |
+
embedding_output,
|
1028 |
+
output_hidden_states=True,
|
1029 |
+
output_attentions=output_attentions,
|
1030 |
+
return_dict=return_dict,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
1034 |
+
|
1035 |
+
feature_maps = ()
|
1036 |
+
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
1037 |
+
if stage in self.out_features:
|
1038 |
+
if self.config.apply_layernorm:
|
1039 |
+
hidden_state = self.layernorm(hidden_state)
|
1040 |
+
if self.config.reshape_hidden_states:
|
1041 |
+
batch_size, _, height, width = pixel_values.shape
|
1042 |
+
patch_size = self.config.patch_size
|
1043 |
+
hidden_state = hidden_state[:, 1:, :].reshape(
|
1044 |
+
batch_size, width // patch_size, height // patch_size, -1
|
1045 |
+
)
|
1046 |
+
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
1047 |
+
feature_maps += (hidden_state,)
|
1048 |
+
|
1049 |
+
if not return_dict:
|
1050 |
+
if output_hidden_states:
|
1051 |
+
output = (feature_maps,) + outputs[1:]
|
1052 |
+
else:
|
1053 |
+
output = (feature_maps,) + outputs[2:]
|
1054 |
+
return output
|
1055 |
+
|
1056 |
+
return BackboneOutput(
|
1057 |
+
feature_maps=feature_maps,
|
1058 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1059 |
+
attentions=outputs.attentions if output_attentions else None,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
|
1063 |
+
class CustomPatchEmbeddings(nn.Module):
|
1064 |
+
"""
|
1065 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
1066 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
1067 |
+
Transformer.
|
1068 |
+
"""
|
1069 |
+
|
1070 |
+
def __init__(
|
1071 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1072 |
+
):
|
1073 |
+
super().__init__()
|
1074 |
+
|
1075 |
+
image_size = (
|
1076 |
+
image_size
|
1077 |
+
if isinstance(image_size, collections.abc.Iterable)
|
1078 |
+
else (image_size, image_size)
|
1079 |
+
)
|
1080 |
+
patch_size = (
|
1081 |
+
patch_size
|
1082 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
1083 |
+
else (patch_size, patch_size)
|
1084 |
+
)
|
1085 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
1086 |
+
image_size[0] // patch_size[0]
|
1087 |
+
)
|
1088 |
+
self.image_size = image_size
|
1089 |
+
self.patch_size = patch_size
|
1090 |
+
self.num_channels = num_channels
|
1091 |
+
self.num_patches = num_patches
|
1092 |
+
|
1093 |
+
self.projection = nn.Conv2d(
|
1094 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
1098 |
+
num_channels = pixel_values.shape[1]
|
1099 |
+
if num_channels != self.num_channels:
|
1100 |
+
raise ValueError(
|
1101 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
1102 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
1103 |
+
)
|
1104 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
1105 |
+
return embeddings
|
1106 |
+
|
1107 |
+
|
1108 |
+
class CustomEmbeddings(nn.Module):
|
1109 |
+
"""
|
1110 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
1111 |
+
"""
|
1112 |
+
|
1113 |
+
def __init__(
|
1114 |
+
self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
|
1115 |
+
) -> None:
|
1116 |
+
super().__init__()
|
1117 |
+
|
1118 |
+
self.image_size = image_size
|
1119 |
+
self.patch_size = patch_size
|
1120 |
+
self.num_channels = num_channels
|
1121 |
+
self.hidden_size = hidden_size
|
1122 |
+
|
1123 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
1124 |
+
|
1125 |
+
self.patch_embeddings = CustomPatchEmbeddings(
|
1126 |
+
image_size, patch_size, num_channels, hidden_size
|
1127 |
+
)
|
1128 |
+
num_patches = self.patch_embeddings.num_patches
|
1129 |
+
self.position_embeddings = nn.Parameter(
|
1130 |
+
torch.randn(1, num_patches + 1, self.hidden_size)
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
def interpolate_pos_encoding(
|
1134 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
1135 |
+
) -> torch.Tensor:
|
1136 |
+
"""
|
1137 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
1138 |
+
resolution images.
|
1139 |
+
|
1140 |
+
Source:
|
1141 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
1142 |
+
"""
|
1143 |
+
|
1144 |
+
num_patches = embeddings.shape[1] - 1
|
1145 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
1146 |
+
if num_patches == num_positions and height == width:
|
1147 |
+
return self.position_embeddings
|
1148 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
1149 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
1150 |
+
dim = embeddings.shape[-1]
|
1151 |
+
height = height // self.patch_size
|
1152 |
+
width = width // self.patch_size
|
1153 |
+
# we add a small number to avoid floating point error in the interpolation
|
1154 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
1155 |
+
height, width = height + 0.1, width + 0.1
|
1156 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
1157 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
1158 |
+
)
|
1159 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
1160 |
+
patch_pos_embed = nn.functional.interpolate(
|
1161 |
+
patch_pos_embed,
|
1162 |
+
scale_factor=(
|
1163 |
+
height / math.sqrt(num_positions),
|
1164 |
+
width / math.sqrt(num_positions),
|
1165 |
+
),
|
1166 |
+
mode="bicubic",
|
1167 |
+
align_corners=False,
|
1168 |
+
)
|
1169 |
+
if (
|
1170 |
+
int(height) != patch_pos_embed.shape[-2]
|
1171 |
+
or int(width) != patch_pos_embed.shape[-1]
|
1172 |
+
):
|
1173 |
+
raise ValueError(
|
1174 |
+
"Width or height does not match with the interpolated position embeddings"
|
1175 |
+
)
|
1176 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
1177 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
1178 |
+
|
1179 |
+
def forward(
|
1180 |
+
self,
|
1181 |
+
pixel_values: torch.Tensor,
|
1182 |
+
) -> torch.Tensor:
|
1183 |
+
batch_size, _, height, width = pixel_values.shape
|
1184 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
1185 |
+
embeddings = patch_embeddings
|
1186 |
+
|
1187 |
+
# add the [CLS] token to the embedded patch tokens
|
1188 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
1189 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
1190 |
+
|
1191 |
+
# add positional encoding to each token
|
1192 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
1193 |
+
embeddings, height, width
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
return embeddings
|
spar3d/models/tokenizers/image.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from spar3d.models.tokenizers.dinov2 import Dinov2Model
|
11 |
+
from spar3d.models.transformers.attention import Modulation
|
12 |
+
from spar3d.models.utils import BaseModule
|
13 |
+
|
14 |
+
|
15 |
+
class DINOV2SingleImageTokenizer(BaseModule):
|
16 |
+
@dataclass
|
17 |
+
class Config(BaseModule.Config):
|
18 |
+
pretrained_model_name_or_path: str = "facebook/dinov2-large"
|
19 |
+
width: int = 512
|
20 |
+
height: int = 512
|
21 |
+
modulation_cond_dim: int = 768
|
22 |
+
|
23 |
+
cfg: Config
|
24 |
+
|
25 |
+
def configure(self) -> None:
|
26 |
+
self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
|
27 |
+
|
28 |
+
for p in self.model.parameters():
|
29 |
+
p.requires_grad_(False)
|
30 |
+
self.model.eval()
|
31 |
+
|
32 |
+
self.model.set_gradient_checkpointing(False)
|
33 |
+
|
34 |
+
# add modulation
|
35 |
+
modulations = []
|
36 |
+
for layer in self.model.encoder.layer:
|
37 |
+
norm1_modulation = Modulation(
|
38 |
+
self.model.config.hidden_size,
|
39 |
+
self.cfg.modulation_cond_dim,
|
40 |
+
zero_init=True,
|
41 |
+
single_layer=True,
|
42 |
+
)
|
43 |
+
norm2_modulation = Modulation(
|
44 |
+
self.model.config.hidden_size,
|
45 |
+
self.cfg.modulation_cond_dim,
|
46 |
+
zero_init=True,
|
47 |
+
single_layer=True,
|
48 |
+
)
|
49 |
+
layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
|
50 |
+
modulations += [norm1_modulation, norm2_modulation]
|
51 |
+
self.modulations = nn.ModuleList(modulations)
|
52 |
+
|
53 |
+
self.register_buffer(
|
54 |
+
"image_mean",
|
55 |
+
torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
|
56 |
+
persistent=False,
|
57 |
+
)
|
58 |
+
self.register_buffer(
|
59 |
+
"image_std",
|
60 |
+
torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
|
61 |
+
persistent=False,
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
images: Float[Tensor, "B *N C H W"],
|
67 |
+
modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
|
68 |
+
**kwargs,
|
69 |
+
) -> Float[Tensor, "B *N Ct Nt"]:
|
70 |
+
model = self.model
|
71 |
+
|
72 |
+
packed = False
|
73 |
+
if images.ndim == 4:
|
74 |
+
packed = True
|
75 |
+
images = images.unsqueeze(1)
|
76 |
+
if modulation_cond is not None:
|
77 |
+
assert modulation_cond.ndim == 2
|
78 |
+
modulation_cond = modulation_cond.unsqueeze(1)
|
79 |
+
|
80 |
+
batch_size, n_input_views = images.shape[:2]
|
81 |
+
images = (images - self.image_mean) / self.image_std
|
82 |
+
out = model(
|
83 |
+
rearrange(images, "B N C H W -> (B N) C H W"),
|
84 |
+
modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
|
85 |
+
if modulation_cond is not None
|
86 |
+
else None,
|
87 |
+
)
|
88 |
+
local_features = out.last_hidden_state
|
89 |
+
local_features = local_features.permute(0, 2, 1)
|
90 |
+
local_features = rearrange(
|
91 |
+
local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
|
92 |
+
)
|
93 |
+
if packed:
|
94 |
+
local_features = local_features.squeeze(1)
|
95 |
+
|
96 |
+
return local_features
|
97 |
+
|
98 |
+
def detokenize(self, *args, **kwargs):
|
99 |
+
raise NotImplementedError
|
spar3d/models/tokenizers/point.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from jaxtyping import Float
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from spar3d.models.transformers.transformer_1d import Transformer1D
|
9 |
+
from spar3d.models.utils import BaseModule
|
10 |
+
|
11 |
+
|
12 |
+
class TransformerPointTokenizer(BaseModule):
|
13 |
+
@dataclass
|
14 |
+
class Config(BaseModule.Config):
|
15 |
+
num_attention_heads: int = 16
|
16 |
+
attention_head_dim: int = 64
|
17 |
+
in_channels: Optional[int] = 6
|
18 |
+
out_channels: Optional[int] = 1024
|
19 |
+
num_layers: int = 16
|
20 |
+
norm_num_groups: int = 32
|
21 |
+
attention_bias: bool = False
|
22 |
+
activation_fn: str = "geglu"
|
23 |
+
norm_elementwise_affine: bool = True
|
24 |
+
|
25 |
+
cfg: Config
|
26 |
+
|
27 |
+
def configure(self) -> None:
|
28 |
+
transformer_cfg = dict(self.cfg.copy())
|
29 |
+
# remove the non-transformer configs
|
30 |
+
transformer_cfg["in_channels"] = (
|
31 |
+
self.cfg.num_attention_heads * self.cfg.attention_head_dim
|
32 |
+
)
|
33 |
+
self.model = Transformer1D(transformer_cfg)
|
34 |
+
self.linear_in = torch.nn.Linear(
|
35 |
+
self.cfg.in_channels, transformer_cfg["in_channels"]
|
36 |
+
)
|
37 |
+
self.linear_out = torch.nn.Linear(
|
38 |
+
transformer_cfg["in_channels"], self.cfg.out_channels
|
39 |
+
)
|
40 |
+
|
41 |
+
def forward(
|
42 |
+
self, points: Float[Tensor, "B N Ci"], **kwargs
|
43 |
+
) -> Float[Tensor, "B N Cp"]:
|
44 |
+
assert points.ndim == 3
|
45 |
+
inputs = self.linear_in(points).permute(0, 2, 1) # B N Ci -> B Ci N
|
46 |
+
out = self.model(inputs).permute(0, 2, 1) # B Ci N -> B N Ci
|
47 |
+
out = self.linear_out(out) # B N Ci -> B N Co
|
48 |
+
return out
|
49 |
+
|
50 |
+
def detokenize(self, *args, **kwargs):
|
51 |
+
raise NotImplementedError
|
spar3d/models/tokenizers/triplane.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from jaxtyping import Float
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from spar3d.models.utils import BaseModule
|
11 |
+
|
12 |
+
|
13 |
+
class TriplaneLearnablePositionalEmbedding(BaseModule):
|
14 |
+
@dataclass
|
15 |
+
class Config(BaseModule.Config):
|
16 |
+
plane_size: int = 96
|
17 |
+
num_channels: int = 1024
|
18 |
+
|
19 |
+
cfg: Config
|
20 |
+
|
21 |
+
def configure(self) -> None:
|
22 |
+
self.embeddings = nn.Parameter(
|
23 |
+
torch.randn(
|
24 |
+
(3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
|
25 |
+
dtype=torch.float32,
|
26 |
+
)
|
27 |
+
* 1
|
28 |
+
/ math.sqrt(self.cfg.num_channels)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
|
32 |
+
return rearrange(
|
33 |
+
repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
|
34 |
+
"B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
|
35 |
+
)
|
36 |
+
|
37 |
+
def detokenize(
|
38 |
+
self, tokens: Float[Tensor, "B Ct Nt"]
|
39 |
+
) -> Float[Tensor, "B 3 Ct Hp Wp"]:
|
40 |
+
batch_size, Ct, Nt = tokens.shape
|
41 |
+
assert Nt == self.cfg.plane_size**2 * 3
|
42 |
+
assert Ct == self.cfg.num_channels
|
43 |
+
return rearrange(
|
44 |
+
tokens,
|
45 |
+
"B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
|
46 |
+
Np=3,
|
47 |
+
Hp=self.cfg.plane_size,
|
48 |
+
Wp=self.cfg.plane_size,
|
49 |
+
)
|