Upload folder using huggingface_hub
- .gitattributes +5 -0
- .gitignore +164 -0
- LICENSE +21 -0
- README.md +2 -6
- app.py +65 -39
- examples/captured.jpeg +3 -0
- examples/chair.png +0 -0
- examples/flamingo.png +0 -0
- examples/hamburger.png +0 -0
- examples/horse.png +0 -0
- examples/iso_house.png +3 -0
- examples/marble.png +0 -0
- examples/police_woman.png +0 -0
- examples/poly_fox.png +0 -0
- examples/robot.png +0 -0
- examples/stripes.png +0 -0
- examples/teapot.png +0 -0
- examples/tiger_girl.png +0 -0
- examples/unicorn.png +0 -0
- figures/comparison800.gif +3 -0
- figures/scatter-comparison.png +0 -0
- figures/teaser800.gif +3 -0
- figures/visual_comparisons.jpg +3 -0
- requirements.txt +2 -2
- run.py +162 -0
- tsr/models/isosurface.py +8 -4
- tsr/models/nerf_renderer.py +62 -55
- tsr/models/network_utils.py +1 -1
- tsr/models/tokenizers/image.py +0 -1
- tsr/models/transformer/attention.py +25 -0
- tsr/models/transformer/basic_transformer_block.py +31 -11
- tsr/models/transformer/transformer_1d.py +41 -38
- tsr/system.py +10 -27
- tsr/utils.py +0 -8
- upload.py +8 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/captured.jpeg filter=lfs diff=lfs merge=lfs -text
+examples/iso_house.png filter=lfs diff=lfs merge=lfs -text
+figures/comparison800.gif filter=lfs diff=lfs merge=lfs -text
+figures/teaser800.gif filter=lfs diff=lfs merge=lfs -text
+figures/visual_comparisons.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file. For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# default output directory
output/
outputs/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Tripo AI & Stability AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,9 @@
-
+
 title: TripoSR
 emoji: 🐳
 colorFrom: gray
 colorTo: red
 sdk: docker
-# sdk_version: 4.19.2
 app_file: app.py
 pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+license: mit
app.py
CHANGED
@@ -13,14 +13,7 @@ from functools import partial
 from tsr.system import TSR
 from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
 
-
-
-HEADER = """
-**TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
-**Tips:**
-1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
-2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
-"""
+import argparse
 
 
 if torch.cuda.is_available():
@@ -28,17 +21,14 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
 
-d = os.environ.get("DEVICE", None)
-if d != None:
-    device = d
-
 model = TSR.from_pretrained(
     "stabilityai/TripoSR",
     config_name="config.yaml",
     weight_name="model.ckpt",
-    # token=HF_TOKEN
 )
-
+
+# adjust the chunk size to balance between speed and memory usage
+model.renderer.set_chunk_size(8192)
 model.to(device)
 
 rembg_session = rembg.new_session()
@@ -68,23 +58,36 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
     return image
 
 
-def generate(image):
+def generate(image, mc_resolution, formats=["obj", "glb"]):
     scene_codes = model(image, device=device)
-    mesh = model.extract_mesh(scene_codes, resolution=…
+    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
     mesh = to_gradio_3d_orientation(mesh)
-    …
+    rv = []
+    for format in formats:
+        mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
+        mesh.export(mesh_path.name)
+        rv.append(mesh_path.name)
+    return rv
+
 
 def run_example(image_pil):
     preprocessed = preprocess(image_pil, False, 0.9)
-    …
-    return preprocessed, …
-    …
+    mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
+    return preprocessed, mesh_name_obj, mesh_name_glb
+
+
+with gr.Blocks(title="TripoSR") as demo:
+    gr.Markdown(
+        """
+    # TripoSR Demo
+    [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
+
+    **Tips:**
+    1. If you find the result unsatisfactory, please try changing the foreground ratio. It might improve the results.
+    2. It's better to disable "Remove Background" for the provided examples (except for the last one) since they have already been preprocessed.
+    3. Otherwise, please disable the "Remove Background" option only if your input image is RGBA with a transparent background, and the image contents are centered and occupy more than 70% of the image width or height.
+    """
+    )
     with gr.Row(variant="panel"):
         with gr.Column():
             with gr.Row():
@@ -108,30 +111,51 @@ with gr.Blocks() as demo:
                     value=0.85,
                     step=0.05,
                 )
+                mc_resolution = gr.Slider(
+                    label="Marching Cubes Resolution",
+                    minimum=32,
+                    maximum=1024,
+                    value=256,
+                    step=32
+                )
             with gr.Row():
                 submit = gr.Button("Generate", elem_id="generate", variant="primary")
         with gr.Column():
-            with gr.Tab("…
-            …
-                label="Output Model",
+            with gr.Tab("OBJ"):
+                output_model_obj = gr.Model3D(
+                    label="Output Model (OBJ Format)",
                     interactive=False,
                 )
-            …
+                gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
+            with gr.Tab("GLB"):
+                output_model_glb = gr.Model3D(
+                    label="Output Model (GLB Format)",
                     interactive=False,
                 )
+                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
    with gr.Row(variant="panel"):
        gr.Examples(
            examples=[
-                …
+                "examples/hamburger.png",
+                "examples/poly_fox.png",
+                "examples/robot.png",
+                "examples/teapot.png",
+                "examples/tiger_girl.png",
+                "examples/horse.png",
+                "examples/flamingo.png",
+                "examples/unicorn.png",
+                "examples/chair.png",
+                "examples/iso_house.png",
+                "examples/marble.png",
+                "examples/police_woman.png",
+                "examples/captured.jpeg",
            ],
            inputs=[input_image],
-            outputs=[processed_image, …
-            …
+            outputs=[processed_image, output_model_obj, output_model_glb],
+            cache_examples=False,
            fn=partial(run_example),
            label="Examples",
-            examples_per_page=20
+            examples_per_page=20,
        )
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
@@ -139,9 +163,11 @@ with gr.Blocks() as demo:
        outputs=[processed_image],
    ).success(
        fn=generate,
-        inputs=[processed_image],
-        outputs=[…
+        inputs=[processed_image, mc_resolution],
+        outputs=[output_model_obj, output_model_glb],
    )
 
+
+
 demo.queue(max_size=10)
-demo.launch()
+demo.launch()
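For reference, the new generate() can also be exercised outside the Gradio UI. A minimal sketch, assuming "preprocessed" is the PIL image returned by preprocess() and using the slider's default marching-cubes resolution of 256; the printed temp-file paths are illustrative:

    # Minimal sketch: call generate() directly with a preprocessed image.
    # It returns paths to temporary mesh files, one per requested format.
    obj_path, glb_path = generate(preprocessed, 256, ["obj", "glb"])
    print(obj_path)  # temporary file ending in .obj (exact path is illustrative)
    print(glb_path)  # temporary file ending in .glb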
examples/captured.jpeg
ADDED (Git LFS)
examples/chair.png
ADDED
examples/flamingo.png
ADDED
examples/hamburger.png
ADDED
examples/horse.png
ADDED
examples/iso_house.png
ADDED (Git LFS)
examples/marble.png
ADDED
examples/police_woman.png
ADDED
examples/poly_fox.png
ADDED
examples/robot.png
ADDED
examples/stripes.png
ADDED
examples/teapot.png
ADDED
examples/tiger_girl.png
ADDED
examples/unicorn.png
ADDED
figures/comparison800.gif
ADDED (Git LFS)
figures/scatter-comparison.png
ADDED
figures/teaser800.gif
ADDED (Git LFS)
figures/visual_comparisons.jpg
ADDED (Git LFS)
requirements.txt
CHANGED
@@ -1,10 +1,10 @@
 omegaconf==2.3.0
 Pillow==10.1.0
 einops==0.7.0
-…
-git+https://github.com/cocktailpeanut/torchmcubes.git
+git+https://github.com/tatsy/torchmcubes.git
 transformers==4.35.0
 trimesh==4.0.5
 rembg
 huggingface-hub
+imageio[ffmpeg]
 gradio
run.py
ADDED
@@ -0,0 +1,162 @@
import argparse
import logging
import os
import time

import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, save_video


class Timer:
    def __init__(self):
        self.items = {}
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"

    def start(self, name: str) -> None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.items[name] = time.time()
        logging.info(f"{name} ...")

    def end(self, name: str) -> float:
        if name not in self.items:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        t = delta * self.time_scale
        logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")


timer = Timer()


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
parser = argparse.ArgumentParser()
parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
parser.add_argument(
    "--device",
    default="cuda:0",
    type=str,
    help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
)
parser.add_argument(
    "--pretrained-model-name-or-path",
    default="stabilityai/TripoSR",
    type=str,
    help="Path to the pretrained model. Could be either a huggingface model id or a local path. Default: 'stabilityai/TripoSR'",
)
parser.add_argument(
    "--chunk-size",
    default=8192,
    type=int,
    help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
)
parser.add_argument(
    "--mc-resolution",
    default=256,
    type=int,
    help="Marching cubes grid resolution. Default: 256"
)
parser.add_argument(
    "--no-remove-bg",
    action="store_true",
    help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
)
parser.add_argument(
    "--foreground-ratio",
    default=0.85,
    type=float,
    help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
)
parser.add_argument(
    "--output-dir",
    default="output/",
    type=str,
    help="Output directory to save the results. Default: 'output/'",
)
parser.add_argument(
    "--model-save-format",
    default="obj",
    type=str,
    choices=["obj", "glb"],
    help="Format to save the extracted mesh. Default: 'obj'",
)
parser.add_argument(
    "--render",
    action="store_true",
    help="If specified, save a NeRF-rendered video. Default: false",
)
args = parser.parse_args()

output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

device = args.device
if not torch.cuda.is_available():
    device = "cpu"

timer.start("Initializing model")
model = TSR.from_pretrained(
    args.pretrained_model_name_or_path,
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(args.chunk_size)
model.to(device)
timer.end("Initializing model")

timer.start("Processing images")
images = []

if args.no_remove_bg:
    rembg_session = None
else:
    rembg_session = rembg.new_session()

for i, image_path in enumerate(args.image):
    if args.no_remove_bg:
        image = np.array(Image.open(image_path).convert("RGB"))
    else:
        image = remove_background(Image.open(image_path), rembg_session)
        image = resize_foreground(image, args.foreground_ratio)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        if not os.path.exists(os.path.join(output_dir, str(i))):
            os.makedirs(os.path.join(output_dir, str(i)))
        image.save(os.path.join(output_dir, str(i), f"input.png"))
    images.append(image)
timer.end("Processing images")

for i, image in enumerate(images):
    logging.info(f"Running image {i + 1}/{len(images)} ...")

    timer.start("Running model")
    with torch.no_grad():
        scene_codes = model([image], device=device)
    timer.end("Running model")

    if args.render:
        timer.start("Rendering")
        render_images = model.render(scene_codes, n_views=30, return_type="pil")
        for ri, render_image in enumerate(render_images[0]):
            render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
        save_video(
            render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
        )
        timer.end("Rendering")

    timer.start("Exporting mesh")
    meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
    meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
    timer.end("Exporting mesh")
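As a usage note, the file above is a command-line tool (for example: python run.py examples/chair.png --output-dir output/ --render). The following is a minimal programmatic sketch of the same flow for a single image, using only API calls shown in this commit; it skips run.py's background removal and gray-background compositing, so the input path and resolution here are illustrative assumptions:

    import torch
    from PIL import Image
    from tsr.system import TSR

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = TSR.from_pretrained(
        "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
    )
    model.renderer.set_chunk_size(8192)  # same speed/VRAM trade-off as --chunk-size
    model.to(device)

    image = Image.open("examples/chair.png").convert("RGB")  # assumes a preprocessed input
    with torch.no_grad():
        scene_codes = model([image], device=device)          # batched call, as in run.py

    mesh = model.extract_mesh(scene_codes, resolution=256)[0]
    mesh.export("mesh.obj")                                   # obj or glb, as in --model-save-format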
tsr/models/isosurface.py
CHANGED
@@ -7,7 +7,7 @@ from torchmcubes import marching_cubes
 
 
 class IsosurfaceHelper(nn.Module):
-    points_range: Tuple[float, float] = (…
+    points_range: Tuple[float, float] = (0, 1)
 
     @property
     def grid_vertices(self) -> torch.FloatTensor:
@@ -41,8 +41,12 @@ class MarchingCubeHelper(IsosurfaceHelper):
         self,
         level: torch.FloatTensor,
     ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
-        level = level.view(self.resolution, self.resolution, self.resolution)
-        …
+        level = -level.view(self.resolution, self.resolution, self.resolution)
+        try:
+            v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
+        except AttributeError:
+            print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
+            v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
         v_pos = v_pos[..., [2, 1, 0]]
-        v_pos = v_pos …
+        v_pos = v_pos / (self.resolution - 1.0)
         return v_pos.to(level.device), t_pos_idx.to(level.device)
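A small illustration of the new vertex normalization (not part of the commit): torchmcubes returns vertices in integer grid coordinates 0..resolution-1, and dividing by resolution - 1.0 maps them into the helper's points_range = (0, 1):

    # Grid index -> normalized coordinate, as done by `v_pos / (self.resolution - 1.0)`.
    resolution = 256
    for idx in (0.0, 127.5, 255.0):
        print(idx / (resolution - 1.0))  # 0.0, 0.5, 1.0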
tsr/models/nerf_renderer.py
CHANGED
@@ -1,9 +1,10 @@
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Dict, Optional
 
 import torch
 import torch.nn.functional as F
 from einops import rearrange, reduce
+from torchmcubes import marching_cubes
 
 from ..utils import (
     BaseModule,
@@ -37,73 +38,79 @@ class TriplaneNeRFRenderer(BaseModule):
             chunk_size >= 0
         ), "chunk_size must be a non-negative integer (0 for no chunking)."
         self.chunk_size = chunk_size
-    def make_step_grid(self, device, resolution: int, chunk_size: int = 32):
-        coords = torch.linspace(-1.0, 1.0, resolution, device=device)
-        x, y, z = torch.meshgrid(coords[0:chunk_size], coords, coords, indexing="ij")
-        x = x.reshape(-1, 1)
-        y = y.reshape(-1, 1)
-        z = z.reshape(-1, 1)
-        verts = torch.cat([x, y, z], dim=-1).view(-1, 3)
-        indices2D: torch.Tensor = torch.stack(
-            (verts[..., [0, 1]], verts[..., [0, 2]], verts[..., [1, 2]]),
-            dim=-3,
-        )
-        return indices2D
 
-    def …
-        …
-        net_out: Dict[str, torch.Tensor] = decoder(out)
-        return net_out['density']
+    def interpolate_triplane(self, triplane: torch.Tensor, resolution: int):
+        coords = torch.linspace(-1.0, 1.0, resolution, device=triplane.device)
+        x, y = torch.meshgrid(coords, coords, indexing="ij")
+        verts2D = torch.cat([x.view(resolution, resolution, 1), y.view(resolution, resolution, 1)], dim=-1)
+        verts2D = verts2D.expand(3, -1, -1, -1)
+        return F.grid_sample(triplane, verts2D, align_corners=False, mode="bilinear")  # [3 40 H W] xy xz yz
+
+    def block_based_marchingcube(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, threshold, block_resolution=128) -> torch.Tensor:
+        resolution += 1  # sample 1 more line of density: 1024 + 1 == 1025 samples, index 0 mapping to -1.0, 512 to 0.0, 1024 to 1.0, for better floating point precision
+        block_size = 2.0 * block_resolution / (resolution - 1)
+        voxel_size = block_size / block_resolution
+        interpolated = self.interpolate_triplane(triplane, resolution)
+
+        pos_list = []
+        indices_list = []
+        for x in range(0, resolution - 1, block_resolution):
+            size_x = resolution - x if x + block_resolution >= resolution else block_resolution + 1  # sample 1 more line of density, so the marching cubes resolution matches block_resolution
+            for y in range(0, resolution - 1, block_resolution):
+                size_y = resolution - y if y + block_resolution >= resolution else block_resolution + 1
+                for z in range(0, resolution - 1, block_resolution):
+                    size_z = resolution - z if z + block_resolution >= resolution else block_resolution + 1
+                    xyplane = interpolated[0:1, :, x:x+size_x, y:y+size_y].expand(size_z, -1, -1, -1, -1).permute(3, 4, 0, 1, 2)
+                    xzplane = interpolated[1:2, :, x:x+size_x, z:z+size_z].expand(size_y, -1, -1, -1, -1).permute(3, 0, 4, 1, 2)
+                    yzplane = interpolated[2:3, :, y:y+size_y, z:z+size_z].expand(size_x, -1, -1, -1, -1).permute(0, 3, 4, 1, 2)
+                    sz = size_x * size_y * size_z
+                    out = torch.cat([xyplane, xzplane, yzplane], dim=3).view(sz, 3, -1)
+
+                    if self.cfg.feature_reduction == "concat":
+                        out = out.view(sz, -1)
+                    elif self.cfg.feature_reduction == "mean":
+                        out = reduce(out, "N Np Cp -> N Cp", Np=3, reduction="mean")
+                    else:
+                        raise NotImplementedError
+                    net_out = decoder(out)
+                    out = None  # discard samples
+                    density = net_out["density"]
+                    net_out = None  # discard colors
+                    density = get_activation(self.cfg.density_activation)(density + self.cfg.density_bias).view(size_x, size_y, size_z)
+                    try:  # now do the marching cube
+                        v_pos, indices = marching_cubes(density.detach(), threshold)
+                    except AttributeError:
+                        print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
+                        v_pos, indices = self.mc_func(density.detach().cpu(), 0.0)
+                    offset = torch.tensor([x * voxel_size - 1.0, y * voxel_size - 1.0, z * voxel_size - 1.0], device=triplane.device)
+                    v_pos = v_pos[..., [2, 1, 0]] * voxel_size + offset
+
+                    indices_list.append(indices)
+                    pos_list.append(v_pos)
+
+        vertex_count = 0
+        for i in range(0, len(pos_list)):
+            indices_list[i] += vertex_count
+            vertex_count += pos_list[i].size(0)
+
+        return torch.cat(pos_list), torch.cat(indices_list)
 
     def query_triplane(
         self,
         decoder: torch.nn.Module,
         positions: torch.Tensor,
         triplane: torch.Tensor,
+        scale_pos=True,
     ) -> Dict[str, torch.Tensor]:
         input_shape = positions.shape[:-1]
         positions = positions.view(-1, 3)
 
         # positions in (-radius, radius)
         # normalized to (-1, 1) for grid sample
-        …
+        if scale_pos:
+            positions = scale_tensor(
+                positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
+            )
 
         def _query_chunk(x):
             indices2D: torch.Tensor = torch.stack(
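To make the block decomposition above easier to follow, here is a standalone sketch (illustrative, not part of the commit) of how block_based_marchingcube tiles one axis: the grid gets one extra line of samples, and every block except possibly the last requests block_resolution + 1 samples, so adjacent blocks share a boundary plane and the extracted surfaces line up:

    # Reproduce the per-axis block layout used by block_based_marchingcube.
    def blocks(resolution=256, block_resolution=128):
        resolution += 1  # one extra line of density samples, as in the renderer
        layout = []
        for start in range(0, resolution - 1, block_resolution):
            size = resolution - start if start + block_resolution >= resolution else block_resolution + 1
            layout.append((start, size))
        return layout

    print(blocks(256, 128))   # [(0, 129), (128, 129)]: samples 0..128 and 128..256 share plane 128
    print(blocks(1024, 128))  # eight blocks of 129 samples each along the axis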
tsr/models/network_utils.py
CHANGED
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
tsr/models/tokenizers/image.py
CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 import torch.nn as nn
tsr/models/transformer/attention.py
CHANGED
@@ -11,6 +11,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 from typing import Optional
 
 import torch
tsr/models/transformer/basic_transformer_block.py
CHANGED
@@ -11,8 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
 
-from typing import …
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -32,8 +56,6 @@ class BasicTransformerBlock(nn.Module):
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        num_embeds_ada_norm (:
-            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
         attention_bias (:
             obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
         only_cross_attention (`bool`, *optional*):
@@ -48,8 +70,6 @@ class BasicTransformerBlock(nn.Module):
             The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
         final_dropout (`bool` *optional*, defaults to False):
             Whether to apply a final dropout after the last feed-forward layer.
-        attention_type (`str`, *optional*, defaults to `"default"`):
-            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
     """
 
     def __init__(
@@ -95,9 +115,9 @@ class BasicTransformerBlock(nn.Module):
 
         self.attn2 = Attention(
             query_dim=dim,
-            cross_attention_dim=…
-            …
+            cross_attention_dim=(
+                cross_attention_dim if not double_self_attention else None
+            ),
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
@@ -139,9 +159,9 @@ class BasicTransformerBlock(nn.Module):
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=…
-            …
+            encoder_hidden_states=(
+                encoder_hidden_states if self.only_cross_attention else None
+            ),
             attention_mask=attention_mask,
         )
 
tsr/models/transformer/transformer_1d.py
CHANGED
@@ -1,5 +1,43 @@
-…
-…
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from dataclasses import dataclass
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -10,28 +48,6 @@ from .basic_transformer_block import BasicTransformerBlock
 
 
 class Transformer1D(BaseModule):
-    """
-    A 1D Transformer model for sequence data.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            The number of channels in the input and output (specify if the input is **continuous**).
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
-        num_embeds_ada_norm ( `int`, *optional*):
-            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
-            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
-            added to the hidden states.
-
-            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the `TransformerBlocks` attention should contain a bias parameter.
-    """
-
     @dataclass
     class Config(BaseModule.Config):
         num_attention_heads: int = 16
@@ -119,15 +135,6 @@ class Transformer1D(BaseModule):
             encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
-            timestep ( `torch.LongTensor`, *optional*):
-                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
-            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
-                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
-                `AdaLayerZeroNorm`.
-            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             attention_mask ( `torch.Tensor`, *optional*):
                 An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                 is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
@@ -140,13 +147,9 @@ class Transformer1D(BaseModule):
 
                 If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                 above. This bias will be added to the cross-attention scores.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
-                tuple.
 
         Returns:
-
-            `tuple` where the first element is the sample tensor.
+            torch.FloatTensor
         """
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
tsr/system.py
CHANGED
@@ -13,7 +13,6 @@ from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
 from PIL import Image
 
-from .models.isosurface import MarchingCubeHelper
 from .utils import (
     BaseModule,
     ImagePreprocessor,
@@ -50,17 +49,17 @@ class TSR(BaseModule):
 
     @classmethod
     def from_pretrained(
-        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
+        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
     ):
         if os.path.isdir(pretrained_model_name_or_path):
             config_path = os.path.join(pretrained_model_name_or_path, config_name)
             weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
         else:
             config_path = hf_hub_download(
-                repo_id=pretrained_model_name_or_path, filename=config_name
+                repo_id=pretrained_model_name_or_path, filename=config_name
             )
             weight_path = hf_hub_download(
-                repo_id=pretrained_model_name_or_path, filename=weight_name
+                repo_id=pretrained_model_name_or_path, filename=weight_name
            )
 
         cfg = OmegaConf.load(config_path)
@@ -160,36 +159,20 @@ class TSR(BaseModule):
 
         return images
 
-    def set_marching_cubes_resolution(self, resolution: int):
-        if (
-            self.isosurface_helper is not None
-            and self.isosurface_helper.resolution == resolution
-        ):
-            return
-        self.isosurface_helper = MarchingCubeHelper(resolution)
-
     def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
-        self.set_marching_cubes_resolution(resolution)
         meshes = []
         for scene_code in scene_codes:
             with torch.no_grad():
-                …
-                    self.decoder.to(scene_codes.device),
-                    scene_code,
-                    resolution
-                ) - threshold
-                v_pos, t_pos_idx = self.isosurface_helper(density)
-                density = None
-                v_pos = v_pos.to(scene_codes.device)
-                color = self.renderer.query_triplane(
-                    self.decoder.to(scene_codes.device),
-                    v_pos,
+                v_pos, t_pos_idx = self.renderer.block_based_marchingcube(self.decoder.to(scene_codes.device),
                     scene_code,
-                    …
+                    resolution,
+                    threshold
+                )
+                color = self.renderer.query_triplane(self.decoder.to(scene_codes.device), v_pos.to(scene_codes.device), scene_code, False)["color"]
                 v_pos = scale_tensor(
                     v_pos,
-                    …
-                    (-self.renderer.cfg.radius, self.renderer.cfg.radius)
+                    (-1.0, 1.0),
+                    (-self.renderer.cfg.radius, self.renderer.cfg.radius)
                 )
                 mesh = trimesh.Trimesh(
                     vertices=v_pos.cpu().numpy(),
tsr/utils.py
CHANGED
@@ -300,7 +300,6 @@ def get_rays(
     directions,
     c2w,
     keepdim=False,
-    noise_scale=0.0,
     normalize=False,
 ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
     # Rotate ray directions from camera coordinate to the world coordinate
@@ -331,12 +330,6 @@ def get_rays(
     )  # (B, H, W, 3)
     rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
 
-    # add camera noise to avoid grid-like artifect
-    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
-    if noise_scale > 0:
-        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
-        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
-
     if normalize:
         rays_d = F.normalize(rays_d, dim=-1)
     if not keepdim:
@@ -477,6 +470,5 @@ def save_video(
 
 def to_gradio_3d_orientation(mesh):
     mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
-    mesh.apply_scale([1, 1, -1])
     mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
     return mesh
upload.py
ADDED
@@ -0,0 +1,8 @@
from huggingface_hub import HfApi
api = HfApi()

api.upload_folder(
    folder_path="/workspaces/TripoSR",
    repo_id="michaelj/TripoSR",
    repo_type="space",
)
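One design note on the upload script: whether locally generated artifacts such as output/ (which the new .gitignore excludes from git) also get pushed depends on how the installed huggingface_hub version handles ignore rules; the library's ignore_patterns argument makes the exclusion explicit. A hedged variant, where the patterns are assumptions and not part of this commit:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_folder(
        folder_path="/workspaces/TripoSR",
        repo_id="michaelj/TripoSR",
        repo_type="space",
        ignore_patterns=["output/**", "outputs/**", "**/__pycache__/**"],  # assumed filters
    )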