HouIP hysts HF staff commited on
Commit
57432d2
·
0 Parent(s):

Duplicate from hysts/Shap-E

Browse files

Co-authored-by: hysts <hysts@users.noreply.huggingface.co>

Files changed (15) hide show
  1. .gitattributes +34 -0
  2. .gitignore +164 -0
  3. .pre-commit-config.yaml +36 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. LICENSE.Shap-E +21 -0
  7. README.md +18 -0
  8. app.py +28 -0
  9. app_image_to_3d.py +76 -0
  10. app_text_to_3d.py +89 -0
  11. model.py +148 -0
  12. requirements.txt +5 -0
  13. settings.py +7 -0
  14. style.css +13 -0
  15. utils.py +9 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_cached_examples/
2
+ shap_e_model_cache/
3
+ corgi.png
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ additional_dependencies: ['types-python-slugify']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE.Shap-E ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 OpenAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Shap-E
3
+ emoji: 🧢
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.34.0
8
+ python_version: 3.10.11
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ suggested_hardware: t4-small
13
+ duplicated_from: hysts/Shap-E
14
+ ---
15
+
16
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
17
+
18
+ https://arxiv.org/abs/2305.02463
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+
5
+ import gradio as gr
6
+ import torch
7
+
8
+ from app_image_to_3d import create_demo as create_demo_image_to_3d
9
+ from app_text_to_3d import create_demo as create_demo_text_to_3d
10
+ from model import Model
11
+
12
+ DESCRIPTION = '# [Shap-E](https://github.com/openai/shap-e)'
13
+
14
+ if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
15
+ DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
16
+ if not torch.cuda.is_available():
17
+ DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
18
+
19
+ model = Model()
20
+
21
+ with gr.Blocks(css='style.css') as demo:
22
+ gr.Markdown(DESCRIPTION)
23
+ with gr.Tabs():
24
+ with gr.Tab(label='Text to 3D'):
25
+ create_demo_text_to_3d(model)
26
+ with gr.Tab(label='Image to 3D'):
27
+ create_demo_image_to_3d(model)
28
+ demo.queue(max_size=10).launch()
app_image_to_3d.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import pathlib
4
+ import shlex
5
+ import subprocess
6
+
7
+ import gradio as gr
8
+
9
+ from model import Model
10
+ from settings import CACHE_EXAMPLES, MAX_SEED
11
+ from utils import randomize_seed_fn
12
+
13
+
14
+ def create_demo(model: Model) -> gr.Blocks:
15
+ if not pathlib.Path('corgi.png').exists():
16
+ subprocess.run(
17
+ shlex.split(
18
+ 'wget https://raw.githubusercontent.com/openai/shap-e/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/examples/example_data/corgi.png -O corgi.png'
19
+ ))
20
+ examples = ['corgi.png']
21
+
22
+ def process_example_fn(image_path: str) -> str:
23
+ return model.run_image(image_path)
24
+
25
+ with gr.Blocks() as demo:
26
+ with gr.Box():
27
+ image = gr.Image(label='Input image',
28
+ show_label=False,
29
+ type='filepath')
30
+ run_button = gr.Button('Run')
31
+ result = gr.Model3D(label='Result', show_label=False)
32
+ with gr.Accordion('Advanced options', open=False):
33
+ seed = gr.Slider(label='Seed',
34
+ minimum=0,
35
+ maximum=MAX_SEED,
36
+ step=1,
37
+ value=0)
38
+ randomize_seed = gr.Checkbox(label='Randomize seed',
39
+ value=True)
40
+ guidance_scale = gr.Slider(label='Guidance scale',
41
+ minimum=1,
42
+ maximum=20,
43
+ step=0.1,
44
+ value=3.0)
45
+ num_inference_steps = gr.Slider(
46
+ label='Number of inference steps',
47
+ minimum=1,
48
+ maximum=100,
49
+ step=1,
50
+ value=64)
51
+
52
+ gr.Examples(examples=examples,
53
+ inputs=image,
54
+ outputs=result,
55
+ fn=process_example_fn,
56
+ cache_examples=CACHE_EXAMPLES)
57
+
58
+ inputs = [
59
+ image,
60
+ seed,
61
+ guidance_scale,
62
+ num_inference_steps,
63
+ ]
64
+
65
+ run_button.click(
66
+ fn=randomize_seed_fn,
67
+ inputs=[seed, randomize_seed],
68
+ outputs=seed,
69
+ queue=False,
70
+ ).then(
71
+ fn=model.run_image,
72
+ inputs=inputs,
73
+ outputs=result,
74
+ api_name='image-to-3d',
75
+ )
76
+ return demo
app_text_to_3d.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import gradio as gr
4
+
5
+ from model import Model
6
+ from settings import CACHE_EXAMPLES, MAX_SEED
7
+ from utils import randomize_seed_fn
8
+
9
+
10
+ def create_demo(model: Model) -> gr.Blocks:
11
+ examples = [
12
+ 'A chair that looks like an avocado',
13
+ 'An airplane that looks like a banana',
14
+ 'A spaceship',
15
+ 'A birthday cupcake',
16
+ 'A chair that looks like a tree',
17
+ 'A green boot',
18
+ 'A penguin',
19
+ 'Ube ice cream cone',
20
+ 'A bowl of vegetables',
21
+ ]
22
+
23
+ def process_example_fn(prompt: str) -> str:
24
+ return model.run_text(prompt)
25
+
26
+ with gr.Blocks() as demo:
27
+ with gr.Box():
28
+ with gr.Row(elem_id='prompt-container'):
29
+ prompt = gr.Text(
30
+ label='Prompt',
31
+ show_label=False,
32
+ max_lines=1,
33
+ placeholder='Enter your prompt').style(container=False)
34
+ run_button = gr.Button('Run').style(full_width=False)
35
+ result = gr.Model3D(label='Result', show_label=False)
36
+ with gr.Accordion('Advanced options', open=False):
37
+ seed = gr.Slider(label='Seed',
38
+ minimum=0,
39
+ maximum=MAX_SEED,
40
+ step=1,
41
+ value=0)
42
+ randomize_seed = gr.Checkbox(label='Randomize seed',
43
+ value=True)
44
+ guidance_scale = gr.Slider(label='Guidance scale',
45
+ minimum=1,
46
+ maximum=20,
47
+ step=0.1,
48
+ value=15.0)
49
+ num_inference_steps = gr.Slider(
50
+ label='Number of inference steps',
51
+ minimum=1,
52
+ maximum=100,
53
+ step=1,
54
+ value=64)
55
+
56
+ gr.Examples(examples=examples,
57
+ inputs=prompt,
58
+ outputs=result,
59
+ fn=process_example_fn,
60
+ cache_examples=CACHE_EXAMPLES)
61
+
62
+ inputs = [
63
+ prompt,
64
+ seed,
65
+ guidance_scale,
66
+ num_inference_steps,
67
+ ]
68
+ prompt.submit(
69
+ fn=randomize_seed_fn,
70
+ inputs=[seed, randomize_seed],
71
+ outputs=seed,
72
+ queue=False,
73
+ ).then(
74
+ fn=model.run_text,
75
+ inputs=inputs,
76
+ outputs=result,
77
+ )
78
+ run_button.click(
79
+ fn=randomize_seed_fn,
80
+ inputs=[seed, randomize_seed],
81
+ outputs=seed,
82
+ queue=False,
83
+ ).then(
84
+ fn=model.run_text,
85
+ inputs=inputs,
86
+ outputs=result,
87
+ api_name='text-to-3d',
88
+ )
89
+ return demo
model.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+
3
+ import numpy as np
4
+ import torch
5
+ import trimesh
6
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
7
+ from shap_e.diffusion.sample import sample_latents
8
+ from shap_e.models.download import load_config, load_model
9
+ from shap_e.models.nn.camera import (DifferentiableCameraBatch,
10
+ DifferentiableProjectiveCamera)
11
+ from shap_e.models.transmitter.base import Transmitter, VectorDecoder
12
+ from shap_e.rendering.torch_mesh import TorchMesh
13
+ from shap_e.util.collections import AttrDict
14
+ from shap_e.util.image_util import load_image
15
+
16
+
17
+ # Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L15-L42
18
+ def create_pan_cameras(size: int,
19
+ device: torch.device) -> DifferentiableCameraBatch:
20
+ origins = []
21
+ xs = []
22
+ ys = []
23
+ zs = []
24
+ for theta in np.linspace(0, 2 * np.pi, num=20):
25
+ z = np.array([np.sin(theta), np.cos(theta), -0.5])
26
+ z /= np.sqrt(np.sum(z**2))
27
+ origin = -z * 4
28
+ x = np.array([np.cos(theta), -np.sin(theta), 0.0])
29
+ y = np.cross(z, x)
30
+ origins.append(origin)
31
+ xs.append(x)
32
+ ys.append(y)
33
+ zs.append(z)
34
+ return DifferentiableCameraBatch(
35
+ shape=(1, len(xs)),
36
+ flat_camera=DifferentiableProjectiveCamera(
37
+ origin=torch.from_numpy(np.stack(origins,
38
+ axis=0)).float().to(device),
39
+ x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
40
+ y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
41
+ z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
42
+ width=size,
43
+ height=size,
44
+ x_fov=0.7,
45
+ y_fov=0.7,
46
+ ),
47
+ )
48
+
49
+
50
+ # Copied from https://github.com/openai/shap-e/blob/8625e7c15526d8510a2292f92165979268d0e945/shap_e/util/notebooks.py#LL64C1-L76C33
51
+ @torch.no_grad()
52
+ def decode_latent_mesh(
53
+ xm: Transmitter | VectorDecoder,
54
+ latent: torch.Tensor,
55
+ ) -> TorchMesh:
56
+ decoded = xm.renderer.render_views(
57
+ AttrDict(cameras=create_pan_cameras(
58
+ 2, latent.device)), # lowest resolution possible
59
+ params=(xm.encoder if isinstance(xm, Transmitter) else
60
+ xm).bottleneck_to_params(latent[None]),
61
+ options=AttrDict(rendering_mode='stf', render_with_direction=False),
62
+ )
63
+ return decoded.raw_meshes[0]
64
+
65
+
66
+ class Model:
67
+ def __init__(self):
68
+ self.device = torch.device(
69
+ 'cuda' if torch.cuda.is_available() else 'cpu')
70
+ self.xm = load_model('transmitter', device=self.device)
71
+ self.diffusion = diffusion_from_config(load_config('diffusion'))
72
+ self.model_text = None
73
+ self.model_image = None
74
+
75
+ def load_model(self, model_name: str) -> None:
76
+ assert model_name in ['text300M', 'image300M']
77
+ if model_name == 'text300M' and self.model_text is None:
78
+ self.model_text = load_model(model_name, device=self.device)
79
+ elif model_name == 'image300M' and self.model_image is None:
80
+ self.model_image = load_model(model_name, device=self.device)
81
+
82
+ def to_glb(self, latent: torch.Tensor) -> str:
83
+ ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
84
+ delete=False,
85
+ mode='w+b')
86
+ decode_latent_mesh(self.xm, latent).tri_mesh().write_ply(ply_path)
87
+
88
+ mesh = trimesh.load(ply_path.name)
89
+ rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
90
+ mesh = mesh.apply_transform(rot)
91
+ rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
92
+ mesh = mesh.apply_transform(rot)
93
+
94
+ mesh_path = tempfile.NamedTemporaryFile(suffix='.glb', delete=False)
95
+ mesh.export(mesh_path.name, file_type='glb')
96
+
97
+ return mesh_path.name
98
+
99
+ def run_text(self,
100
+ prompt: str,
101
+ seed: int = 0,
102
+ guidance_scale: float = 15.0,
103
+ num_steps: int = 64) -> str:
104
+ self.load_model('text300M')
105
+ torch.manual_seed(seed)
106
+
107
+ latents = sample_latents(
108
+ batch_size=1,
109
+ model=self.model_text,
110
+ diffusion=self.diffusion,
111
+ guidance_scale=guidance_scale,
112
+ model_kwargs=dict(texts=[prompt]),
113
+ progress=True,
114
+ clip_denoised=True,
115
+ use_fp16=True,
116
+ use_karras=True,
117
+ karras_steps=num_steps,
118
+ sigma_min=1e-3,
119
+ sigma_max=160,
120
+ s_churn=0,
121
+ )
122
+ return self.to_glb(latents[0])
123
+
124
+ def run_image(self,
125
+ image_path: str,
126
+ seed: int = 0,
127
+ guidance_scale: float = 3.0,
128
+ num_steps: int = 64) -> str:
129
+ self.load_model('image300M')
130
+ torch.manual_seed(seed)
131
+
132
+ image = load_image(image_path)
133
+ latents = sample_latents(
134
+ batch_size=1,
135
+ model=self.model_image,
136
+ diffusion=self.diffusion,
137
+ guidance_scale=guidance_scale,
138
+ model_kwargs=dict(images=[image]),
139
+ progress=True,
140
+ clip_denoised=True,
141
+ use_fp16=True,
142
+ use_karras=True,
143
+ karras_steps=num_steps,
144
+ sigma_min=1e-3,
145
+ sigma_max=160,
146
+ s_churn=0,
147
+ )
148
+ return self.to_glb(latents[0])
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ git+https://github.com/openai/shap-e@8625e7c
2
+ gradio==3.32.0
3
+ torch==2.0.0
4
+ torchvision==0.15.1
5
+ trimesh==3.21.5
settings.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+ CACHE_EXAMPLES = os.getenv('CACHE_EXAMPLES') == '1'
6
+
7
+ MAX_SEED = np.iinfo(np.int32).max
style.css ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #component-0 {
6
+ max-width: 730px;
7
+ margin: auto;
8
+ padding-top: 1.5rem;
9
+ }
10
+
11
+ #prompt-container {
12
+ gap: 0;
13
+ }
utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from settings import MAX_SEED
4
+
5
+
6
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
7
+ if randomize_seed:
8
+ seed = random.randint(0, MAX_SEED)
9
+ return seed