michaelj committed
Commit
a12b8d1
1 Parent(s): 23f5383

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/captured.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ examples/iso_house.png filter=lfs diff=lfs merge=lfs -text
38
+ figures/comparison800.gif filter=lfs diff=lfs merge=lfs -text
39
+ figures/teaser800.gif filter=lfs diff=lfs merge=lfs -text
40
+ figures/visual_comparisons.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # default output directory
163
+ output/
164
+ outputs/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Tripo AI & Stability AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,9 @@
1
- ---
2
  title: TripoSR
3
  emoji: 🐳
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: docker
7
- # sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+
2
  title: TripoSR
3
  emoji: 🐳
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: docker
 
7
  app_file: app.py
8
  pinned: false
9
+ license: mit
 
 
 
app.py CHANGED
@@ -13,14 +13,7 @@ from functools import partial
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
- #HF_TOKEN = os.getenv("HF_TOKEN")
17
-
18
- HEADER = """
19
- **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
20
- **Tips:**
21
- 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
22
- 2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
23
- """
24
 
25
 
26
  if torch.cuda.is_available():
@@ -28,17 +21,14 @@ if torch.cuda.is_available():
28
  else:
29
  device = "cpu"
30
 
31
- d = os.environ.get("DEVICE", None)
32
- if d != None:
33
- device = d
34
-
35
  model = TSR.from_pretrained(
36
  "stabilityai/TripoSR",
37
  config_name="config.yaml",
38
  weight_name="model.ckpt",
39
- # token=HF_TOKEN
40
  )
41
- model.renderer.set_chunk_size(131072)
 
 
42
  model.to(device)
43
 
44
  rembg_session = rembg.new_session()
@@ -68,23 +58,36 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
68
  return image
69
 
70
 
71
- def generate(image):
72
  scene_codes = model(image, device=device)
73
- mesh = model.extract_mesh(scene_codes, resolution=1024)[0]
74
  mesh = to_gradio_3d_orientation(mesh)
75
- mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
76
- mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
77
- mesh.export(mesh_path.name)
78
- mesh.export(mesh_path2.name)
79
- return mesh_path.name, mesh_path2.name
 
 
80
 
81
  def run_example(image_pil):
82
  preprocessed = preprocess(image_pil, False, 0.9)
83
- mesh_name, mesn_name2 = generate(preprocessed)
84
- return preprocessed, mesh_name, mesh_name2
85
-
86
- with gr.Blocks() as demo:
87
- gr.Markdown(HEADER)
88
  with gr.Row(variant="panel"):
89
  with gr.Column():
90
  with gr.Row():
@@ -108,30 +111,51 @@ with gr.Blocks() as demo:
108
  value=0.85,
109
  step=0.05,
110
  )
111
  with gr.Row():
112
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
113
  with gr.Column():
114
- with gr.Tab("obj"):
115
- output_model = gr.Model3D(
116
- label="Output Model",
117
  interactive=False,
118
  )
119
- with gr.Tab("glb"):
120
- output_model2 = gr.Model3D(
121
- label="Output Model",
 
122
  interactive=False,
123
  )
 
124
  with gr.Row(variant="panel"):
125
  gr.Examples(
126
  examples=[
127
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
128
  ],
129
  inputs=[input_image],
130
- outputs=[processed_image, output_model, output_model2],
131
- #cache_examples=True,
132
  fn=partial(run_example),
133
  label="Examples",
134
- examples_per_page=20
135
  )
136
  submit.click(fn=check_input_image, inputs=[input_image]).success(
137
  fn=preprocess,
@@ -139,9 +163,11 @@ with gr.Blocks() as demo:
139
  outputs=[processed_image],
140
  ).success(
141
  fn=generate,
142
- inputs=[processed_image],
143
- outputs=[output_model, output_model2],
144
  )
145
 
 
 
146
  demo.queue(max_size=10)
147
- demo.launch()
 
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
+ import argparse
17
 
18
 
19
  if torch.cuda.is_available():
 
21
  else:
22
  device = "cpu"
23
 
 
 
 
 
24
  model = TSR.from_pretrained(
25
  "stabilityai/TripoSR",
26
  config_name="config.yaml",
27
  weight_name="model.ckpt",
 
28
  )
29
+
30
+ # adjust the chunk size to balance between speed and memory usage
31
+ model.renderer.set_chunk_size(8192)
32
  model.to(device)
33
 
34
  rembg_session = rembg.new_session()
 
58
  return image
59
 
60
 
61
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
62
  scene_codes = model(image, device=device)
63
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
64
  mesh = to_gradio_3d_orientation(mesh)
65
+ rv = []
66
+ for format in formats:
67
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
68
+ mesh.export(mesh_path.name)
69
+ rv.append(mesh_path.name)
70
+ return rv
71
+
72
 
73
  def run_example(image_pil):
74
  preprocessed = preprocess(image_pil, False, 0.9)
75
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
76
+ return preprocessed, mesh_name_obj, mesh_name_glb
77
+
78
+
79
+ with gr.Blocks(title="TripoSR") as demo:
80
+ gr.Markdown(
81
+ """
82
+ # TripoSR Demo
83
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
84
+
85
+ **Tips:**
86
+ 1. If you find the result unsatisfactory, please try changing the foreground ratio. It might improve the results.
87
+ 2. It's better to disable "Remove Background" for the provided examples (except for the last one) since they have already been preprocessed.
88
+ 3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background, and the image contents are centered and occupy more than 70% of the image width or height.
89
+ """
90
+ )
91
  with gr.Row(variant="panel"):
92
  with gr.Column():
93
  with gr.Row():
 
111
  value=0.85,
112
  step=0.05,
113
  )
114
+ mc_resolution = gr.Slider(
115
+ label="Marching Cubes Resolution",
116
+ minimum=32,
117
+ maximum=1024,
118
+ value=256,
119
+ step=32
120
+ )
121
  with gr.Row():
122
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
123
  with gr.Column():
124
+ with gr.Tab("OBJ"):
125
+ output_model_obj = gr.Model3D(
126
+ label="Output Model (OBJ Format)",
127
  interactive=False,
128
  )
129
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
130
+ with gr.Tab("GLB"):
131
+ output_model_glb = gr.Model3D(
132
+ label="Output Model (GLB Format)",
133
  interactive=False,
134
  )
135
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
136
  with gr.Row(variant="panel"):
137
  gr.Examples(
138
  examples=[
139
+ "examples/hamburger.png",
140
+ "examples/poly_fox.png",
141
+ "examples/robot.png",
142
+ "examples/teapot.png",
143
+ "examples/tiger_girl.png",
144
+ "examples/horse.png",
145
+ "examples/flamingo.png",
146
+ "examples/unicorn.png",
147
+ "examples/chair.png",
148
+ "examples/iso_house.png",
149
+ "examples/marble.png",
150
+ "examples/police_woman.png",
151
+ "examples/captured.jpeg",
152
  ],
153
  inputs=[input_image],
154
+ outputs=[processed_image, output_model_obj, output_model_glb],
155
+ cache_examples=False,
156
  fn=partial(run_example),
157
  label="Examples",
158
+ examples_per_page=20,
159
  )
160
  submit.click(fn=check_input_image, inputs=[input_image]).success(
161
  fn=preprocess,
 
163
  outputs=[processed_image],
164
  ).success(
165
  fn=generate,
166
+ inputs=[processed_image, mc_resolution],
167
+ outputs=[output_model_obj, output_model_glb],
168
  )
169
 
170
+
171
+
172
  demo.queue(max_size=10)
173
+ demo.launch()
examples/captured.jpeg ADDED

Git LFS Details

  • SHA256: c6eb2768703a0e3d6034daa7fd5e0b286450b1077a90f36da8110749bb1cb8a8
  • Pointer size: 132 Bytes
  • Size of remote file: 5.94 MB
examples/chair.png ADDED
examples/flamingo.png ADDED
examples/hamburger.png ADDED
examples/horse.png ADDED
examples/iso_house.png ADDED

Git LFS Details

  • SHA256: b6063cbbc55b9aa4a4785ddbfcd13ca86fb07eca5a4ea7f9dda5eebcf7c17765
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
examples/marble.png ADDED
examples/police_woman.png ADDED
examples/poly_fox.png ADDED
examples/robot.png ADDED
examples/stripes.png ADDED
examples/teapot.png ADDED
examples/tiger_girl.png ADDED
examples/unicorn.png ADDED
figures/comparison800.gif ADDED

Git LFS Details

  • SHA256: 887e69297e4446f122801ff2cc39962eda0933906d7ed7be7abf659e721914be
  • Pointer size: 132 Bytes
  • Size of remote file: 8.87 MB
figures/scatter-comparison.png ADDED
figures/teaser800.gif ADDED

Git LFS Details

  • SHA256: 52ecc6ff24e008b0d28236425a1b59718931841f6fb9f5e6f8471829fc9bc292
  • Pointer size: 132 Bytes
  • Size of remote file: 3.84 MB
figures/visual_comparisons.jpg ADDED

Git LFS Details

  • SHA256: 019235d716d8832aaa659acd31cf17267af94df6b5a9beca2a7002b41d59c8db
  • Pointer size: 133 Bytes
  • Size of remote file: 10.3 MB
requirements.txt CHANGED
@@ -1,10 +1,10 @@
1
  omegaconf==2.3.0
2
  Pillow==10.1.0
3
  einops==0.7.0
4
- #git+https://github.com/tatsy/torchmcubes.git
5
- git+https://github.com/cocktailpeanut/torchmcubes.git
6
  transformers==4.35.0
7
  trimesh==4.0.5
8
  rembg
9
  huggingface-hub
 
10
  gradio
 
1
  omegaconf==2.3.0
2
  Pillow==10.1.0
3
  einops==0.7.0
4
+ git+https://github.com/tatsy/torchmcubes.git
 
5
  transformers==4.35.0
6
  trimesh==4.0.5
7
  rembg
8
  huggingface-hub
9
+ imageio[ffmpeg]
10
  gradio
run.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from tsr.system import TSR
12
+ from tsr.utils import remove_background, resize_foreground, save_video
13
+
14
+
15
+ class Timer:
16
+ def __init__(self):
17
+ self.items = {}
18
+ self.time_scale = 1000.0 # ms
19
+ self.time_unit = "ms"
20
+
21
+ def start(self, name: str) -> None:
22
+ if torch.cuda.is_available():
23
+ torch.cuda.synchronize()
24
+ self.items[name] = time.time()
25
+ logging.info(f"{name} ...")
26
+
27
+ def end(self, name: str) -> float:
28
+ if name not in self.items:
29
+ return
30
+ if torch.cuda.is_available():
31
+ torch.cuda.synchronize()
32
+ start_time = self.items.pop(name)
33
+ delta = time.time() - start_time
34
+ t = delta * self.time_scale
35
+ logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
36
+
37
+
38
+ timer = Timer()
39
+
40
+
41
+ logging.basicConfig(
42
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
43
+ )
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
46
+ parser.add_argument(
47
+ "--device",
48
+ default="cuda:0",
49
+ type=str,
50
+ help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
51
+ )
52
+ parser.add_argument(
53
+ "--pretrained-model-name-or-path",
54
+ default="stabilityai/TripoSR",
55
+ type=str,
56
+ help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/TripoSR'",
57
+ )
58
+ parser.add_argument(
59
+ "--chunk-size",
60
+ default=8192,
61
+ type=int,
62
+ help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
63
+ )
64
+ parser.add_argument(
65
+ "--mc-resolution",
66
+ default=256,
67
+ type=int,
68
+ help="Marching cubes grid resolution. Default: 256"
69
+ )
70
+ parser.add_argument(
71
+ "--no-remove-bg",
72
+ action="store_true",
73
+ help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
74
+ )
75
+ parser.add_argument(
76
+ "--foreground-ratio",
77
+ default=0.85,
78
+ type=float,
79
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
80
+ )
81
+ parser.add_argument(
82
+ "--output-dir",
83
+ default="output/",
84
+ type=str,
85
+ help="Output directory to save the results. Default: 'output/'",
86
+ )
87
+ parser.add_argument(
88
+ "--model-save-format",
89
+ default="obj",
90
+ type=str,
91
+ choices=["obj", "glb"],
92
+ help="Format to save the extracted mesh. Default: 'obj'",
93
+ )
94
+ parser.add_argument(
95
+ "--render",
96
+ action="store_true",
97
+ help="If specified, save a NeRF-rendered video. Default: false",
98
+ )
99
+ args = parser.parse_args()
100
+
101
+ output_dir = args.output_dir
102
+ os.makedirs(output_dir, exist_ok=True)
103
+
104
+ device = args.device
105
+ if not torch.cuda.is_available():
106
+ device = "cpu"
107
+
108
+ timer.start("Initializing model")
109
+ model = TSR.from_pretrained(
110
+ args.pretrained_model_name_or_path,
111
+ config_name="config.yaml",
112
+ weight_name="model.ckpt",
113
+ )
114
+ model.renderer.set_chunk_size(args.chunk_size)
115
+ model.to(device)
116
+ timer.end("Initializing model")
117
+
118
+ timer.start("Processing images")
119
+ images = []
120
+
121
+ if args.no_remove_bg:
122
+ rembg_session = None
123
+ else:
124
+ rembg_session = rembg.new_session()
125
+
126
+ for i, image_path in enumerate(args.image):
127
+ if args.no_remove_bg:
128
+ image = np.array(Image.open(image_path).convert("RGB"))
129
+ else:
130
+ image = remove_background(Image.open(image_path), rembg_session)
131
+ image = resize_foreground(image, args.foreground_ratio)
132
+ image = np.array(image).astype(np.float32) / 255.0
133
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
134
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
135
+ if not os.path.exists(os.path.join(output_dir, str(i))):
136
+ os.makedirs(os.path.join(output_dir, str(i)))
137
+ image.save(os.path.join(output_dir, str(i), f"input.png"))
138
+ images.append(image)
139
+ timer.end("Processing images")
140
+
141
+ for i, image in enumerate(images):
142
+ logging.info(f"Running image {i + 1}/{len(images)} ...")
143
+
144
+ timer.start("Running model")
145
+ with torch.no_grad():
146
+ scene_codes = model([image], device=device)
147
+ timer.end("Running model")
148
+
149
+ if args.render:
150
+ timer.start("Rendering")
151
+ render_images = model.render(scene_codes, n_views=30, return_type="pil")
152
+ for ri, render_image in enumerate(render_images[0]):
153
+ render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
154
+ save_video(
155
+ render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
156
+ )
157
+ timer.end("Rendering")
158
+
159
+ timer.start("Exporting mesh")
160
+ meshes = model.extract_mesh(scene_codes, resolution=args.mc_resolution)
161
+ meshes[0].export(os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}"))
162
+ timer.end("Exporting mesh")
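
As a reading aid for the run.py added above, here is a minimal sketch of the same pipeline called directly from Python. It only restates what the script does; the input path "input.png" and output name "mesh.obj" are placeholders, and it assumes the tsr package from this repository is on the import path.

import numpy as np
import rembg
import torch
from PIL import Image

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained model; a smaller chunk size lowers VRAM use at the cost of speed.
model = TSR.from_pretrained(
    "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
)
model.renderer.set_chunk_size(8192)
model.to(device)

# Preprocess as run.py does: remove the background, resize the foreground,
# then composite the RGBA result onto a 50% gray background.
image = remove_background(Image.open("input.png"), rembg.new_session())
image = resize_foreground(image, 0.85)
rgba = np.array(image).astype(np.float32) / 255.0
rgb = rgba[:, :, :3] * rgba[:, :, 3:4] + (1 - rgba[:, :, 3:4]) * 0.5
image = Image.fromarray((rgb * 255.0).astype(np.uint8))

# Run the model and extract a mesh with marching cubes at resolution 256.
with torch.no_grad():
    scene_codes = model([image], device=device)
mesh = model.extract_mesh(scene_codes, resolution=256, threshold=25.0)[0]
mesh.export("mesh.obj")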
tsr/models/isosurface.py CHANGED
@@ -7,7 +7,7 @@ from torchmcubes import marching_cubes
7
 
8
 
9
  class IsosurfaceHelper(nn.Module):
10
- points_range: Tuple[float, float] = (-1, 1)
11
 
12
  @property
13
  def grid_vertices(self) -> torch.FloatTensor:
@@ -41,8 +41,12 @@ class MarchingCubeHelper(IsosurfaceHelper):
41
  self,
42
  level: torch.FloatTensor,
43
  ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44
- level = level.view(self.resolution, self.resolution, self.resolution)
45
- v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
46
  v_pos = v_pos[..., [2, 1, 0]]
47
- v_pos = v_pos * 2.0 / (self.resolution - 1.0) - 1.0
48
  return v_pos.to(level.device), t_pos_idx.to(level.device)
 
7
 
8
 
9
  class IsosurfaceHelper(nn.Module):
10
+ points_range: Tuple[float, float] = (0, 1)
11
 
12
  @property
13
  def grid_vertices(self) -> torch.FloatTensor:
 
41
  self,
42
  level: torch.FloatTensor,
43
  ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
44
+ level = -level.view(self.resolution, self.resolution, self.resolution)
45
+ try:
46
+ v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
47
+ except AttributeError:
48
+ print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
49
+ v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
50
  v_pos = v_pos[..., [2, 1, 0]]
51
+ v_pos = v_pos / (self.resolution - 1.0)
52
  return v_pos.to(level.device), t_pos_idx.to(level.device)
tsr/models/nerf_renderer.py CHANGED
@@ -1,9 +1,10 @@
1
- from dataclasses import dataclass, field
2
  from typing import Dict, Optional
3
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange, reduce
 
7
 
8
  from ..utils import (
9
  BaseModule,
@@ -37,73 +38,79 @@ class TriplaneNeRFRenderer(BaseModule):
37
  chunk_size >= 0
38
  ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
  self.chunk_size = chunk_size
40
- def make_step_grid(self,device, resolution: int, chunk_size: int = 32):
41
- coords = torch.linspace(-1.0, 1.0, resolution, device = device)
42
- x, y, z = torch.meshgrid(coords[0:chunk_size], coords, coords, indexing="ij")
43
- x = x.reshape(-1, 1)
44
- y = y.reshape(-1, 1)
45
- z = z.reshape(-1, 1)
46
- verts = torch.cat([x, y, z], dim = -1).view(-1, 3)
47
- indices2D: torch.Tensor = torch.stack(
48
- (verts[..., [0, 1]], verts[..., [0, 2]], verts[..., [1, 2]]),
49
- dim=-3,
50
- )
51
- return indices2D
52
 
53
- def query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, sample_count: int = 1024 * 1024 * 4) -> torch.Tensor:
54
- layer_count = sample_count // (resolution * resolution)
55
- out_list = self.do_query_triplane_volume_density(decoder, triplane, resolution, layer_count)
56
- return get_activation(self.cfg.density_activation)(
57
- out_list.view([resolution * resolution * resolution, 1]) + self.cfg.density_bias
58
- )
59
- def do_query_triplane_volume_density(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, layer_count: int) -> torch.Tensor:
60
- step = 2.0 * layer_count / (resolution - 1)
61
- indices2D = self.make_step_grid(triplane.device, resolution, layer_count)
62
-
63
- out_list = torch.zeros([resolution, resolution * resolution, 1], device = triplane.device
64
- )
65
- for i in range(0, resolution, layer_count):
66
- if i + layer_count > resolution:
67
- layer_count = resolution - i
68
- indices2D = indices2D[..., :resolution * resolution * layer_count, :]
69
- density_step = self.sample_step_triplane_volume_density(decoder, triplane, indices2D)
70
- # todo directly march cube
71
- out_list[i:i + layer_count] = density_step.view([layer_count, resolution * resolution, 1])
72
- #out_list.append(net_out['density'])
73
- indices2D.transpose(1, 2)[0, 0] += step
74
- indices2D.transpose(1, 2)[1, 0] += step
75
-
76
- return out_list
77
- def sample_step_triplane_volume_density(self, decoder, triplane, indices2D):
78
- out: torch.Tensor = F.grid_sample(
79
- rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
80
- rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
81
- align_corners=False,
82
- mode="bilinear",
83
- )
84
- if self.cfg.feature_reduction == "concat":
85
- out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
86
- elif self.cfg.feature_reduction == "mean":
87
- out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
88
- else:
89
- raise NotImplementedError
90
 
91
- net_out: Dict[str, torch.Tensor] = decoder(out)
92
- return net_out['density']
93
  def query_triplane(
94
  self,
95
  decoder: torch.nn.Module,
96
  positions: torch.Tensor,
97
  triplane: torch.Tensor,
 
98
  ) -> Dict[str, torch.Tensor]:
99
  input_shape = positions.shape[:-1]
100
  positions = positions.view(-1, 3)
101
 
102
  # positions in (-radius, radius)
103
  # normalized to (-1, 1) for grid sample
104
- #positions = scale_tensor(
105
- # positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
106
- #)
 
107
 
108
  def _query_chunk(x):
109
  indices2D: torch.Tensor = torch.stack(
 
1
+ from dataclasses import dataclass
2
  from typing import Dict, Optional
3
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange, reduce
7
+ from torchmcubes import marching_cubes
8
 
9
  from ..utils import (
10
  BaseModule,
 
38
  chunk_size >= 0
39
  ), "chunk_size must be a non-negative integer (0 for no chunking)."
40
  self.chunk_size = chunk_size
41
 
42
+ def interpolate_triplane(self, triplane: torch.Tensor, resolution: int):
43
+ coords = torch.linspace(-1.0, 1.0, resolution, device = triplane.device)
44
+ x, y = torch.meshgrid(coords, coords, indexing="ij")
45
+ verts2D = torch.cat([x.view(resolution, resolution,1), y.view(resolution, resolution,1)], dim = -1)
46
+ verts2D = verts2D.expand(3, -1, -1, -1)
47
+ return F.grid_sample(triplane, verts2D, align_corners=False,mode="bilinear") # [3 40 H W] xy xz yz
48
+
49
+ def block_based_marchingcube(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, threshold, block_resolution = 128) -> torch.Tensor:
50
+ resolution += 1 # sample 1 more line of density, 1024 + 1 == 1025; index 0 maps to -1.0f, 512 to 0.0f, 1024 to 1.0f, for better floating point precision.
51
+ block_size = 2.0 * block_resolution / (resolution - 1)
52
+ voxel_size = block_size / block_resolution
53
+ interpolated = self.interpolate_triplane(triplane, resolution)
54
+
55
+ pos_list = []
56
+ indices_list = []
57
+ for x in range(0, resolution - 1, block_resolution):
58
+ size_x = resolution - x if x + block_resolution >= resolution else block_resolution + 1 # sample 1 more line of density, so marching cubes resolution match block_resolution
59
+ for y in range(0, resolution - 1, block_resolution):
60
+ size_y = resolution - y if y + block_resolution >= resolution else block_resolution + 1
61
+ for z in range(0, resolution - 1, block_resolution):
62
+ size_z = resolution - z if z + block_resolution >= resolution else block_resolution + 1
63
+ xyplane = interpolated[0:1, :, x:x+size_x, y:y+size_y].expand(size_z, -1, -1, -1, -1).permute(3, 4, 0, 1, 2)
64
+ xzplane = interpolated[1:2, :, x:x+size_x, z:z+size_z].expand(size_y, -1, -1, -1, -1).permute(3, 0, 4, 1, 2)
65
+ yzplane = interpolated[2:3, :, y:y+size_y, z:z+size_z].expand(size_x, -1, -1, -1, -1).permute(0, 3, 4, 1, 2)
66
+ sz = size_x * size_y * size_z
67
+ out = torch.cat([xyplane, xzplane, yzplane], dim=3).view(sz, 3, -1)
68
+
69
+ if self.cfg.feature_reduction == "concat":
70
+ out = out.view(sz, -1)
71
+ elif self.cfg.feature_reduction == "mean":
72
+ out = reduce(out, "N Np Cp -> N Cp", Np=3, reduction="mean")
73
+ else:
74
+ raise NotImplementedError
75
+ net_out = decoder(out)
76
+ out = None # discard samples
77
+ density = net_out["density"]
78
+ net_out = None # discard colors
79
+ density = get_activation(self.cfg.density_activation)(density + self.cfg.density_bias).view(size_x, size_y, size_z)
80
+ try: # now do the marching cube
81
+ v_pos, indices = marching_cubes(density.detach(), threshold)
82
+ except AttributeError:
83
+ print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
84
+ v_pos, indices = self.mc_func(density.detach().cpu(), 0.0)
85
+ offset = torch.tensor([x * voxel_size - 1.0, y * voxel_size - 1.0, z * voxel_size - 1.0], device = triplane.device)
86
+ v_pos = v_pos[..., [2, 1, 0]] * voxel_size + offset
87
+
88
+ indices_list.append(indices)
89
+ pos_list.append(v_pos)
90
+
91
+ vertex_count = 0
92
+ for i in range(0, len(pos_list)):
93
+ indices_list[i] += vertex_count
94
+ vertex_count += pos_list[i].size(0)
95
+
96
+ return torch.cat(pos_list), torch.cat(indices_list)
97
 
 
 
98
  def query_triplane(
99
  self,
100
  decoder: torch.nn.Module,
101
  positions: torch.Tensor,
102
  triplane: torch.Tensor,
103
+ scale_pos = True
104
  ) -> Dict[str, torch.Tensor]:
105
  input_shape = positions.shape[:-1]
106
  positions = positions.view(-1, 3)
107
 
108
  # positions in (-radius, radius)
109
  # normalized to (-1, 1) for grid sample
110
+ if scale_pos:
111
+ positions = scale_tensor(
112
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
113
+ )
114
 
115
  def _query_chunk(x):
116
  indices2D: torch.Tensor = torch.stack(
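
One detail worth calling out in the block_based_marchingcube addition above: each block runs marching cubes in its own local index space, so the per-block face indices must be shifted by the number of vertices accumulated so far before the blocks are concatenated (the vertex_count loop near the end). A tiny standalone illustration of that bookkeeping, using toy tensors rather than real decoder output:

import torch

# Two toy "blocks", each with local vertices and faces indexing into those vertices.
pos_list = [torch.rand(4, 3), torch.rand(5, 3)]
indices_list = [
    torch.tensor([[0, 1, 2], [1, 2, 3]]),
    torch.tensor([[0, 1, 2], [2, 3, 4]]),
]

# Shift each block's face indices by the running vertex count so they point at
# the right rows of the concatenated vertex buffer.
vertex_count = 0
for i in range(len(pos_list)):
    indices_list[i] = indices_list[i] + vertex_count
    vertex_count += pos_list[i].size(0)

vertices = torch.cat(pos_list)   # shape (9, 3)
faces = torch.cat(indices_list)  # second block's faces now reference rows 4..8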
tsr/models/network_utils.py CHANGED
@@ -1,4 +1,4 @@
1
- from dataclasses import dataclass, field
2
  from typing import Optional
3
 
4
  import torch
 
1
+ from dataclasses import dataclass
2
  from typing import Optional
3
 
4
  import torch
tsr/models/tokenizers/image.py CHANGED
@@ -1,5 +1,4 @@
1
  from dataclasses import dataclass
2
- from typing import Optional
3
 
4
  import torch
5
  import torch.nn as nn
 
1
  from dataclasses import dataclass
 
2
 
3
  import torch
4
  import torch.nn as nn
tsr/models/transformer/attention.py CHANGED
@@ -11,6 +11,31 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from typing import Optional
15
 
16
  import torch
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
  from typing import Optional
40
 
41
  import torch
tsr/models/transformer/basic_transformer_block.py CHANGED
@@ -11,8 +11,32 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- from typing import Any, Dict, Optional
16
 
17
  import torch
18
  import torch.nn.functional as F
@@ -32,8 +56,6 @@ class BasicTransformerBlock(nn.Module):
32
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
33
  cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
34
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
35
- num_embeds_ada_norm (:
36
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
37
  attention_bias (:
38
  obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
39
  only_cross_attention (`bool`, *optional*):
@@ -48,8 +70,6 @@ class BasicTransformerBlock(nn.Module):
48
  The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
49
  final_dropout (`bool` *optional*, defaults to False):
50
  Whether to apply a final dropout after the last feed-forward layer.
51
- attention_type (`str`, *optional*, defaults to `"default"`):
52
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
53
  """
54
 
55
  def __init__(
@@ -95,9 +115,9 @@ class BasicTransformerBlock(nn.Module):
95
 
96
  self.attn2 = Attention(
97
  query_dim=dim,
98
- cross_attention_dim=cross_attention_dim
99
- if not double_self_attention
100
- else None,
101
  heads=num_attention_heads,
102
  dim_head=attention_head_dim,
103
  dropout=dropout,
@@ -139,9 +159,9 @@ class BasicTransformerBlock(nn.Module):
139
 
140
  attn_output = self.attn1(
141
  norm_hidden_states,
142
- encoder_hidden_states=encoder_hidden_states
143
- if self.only_cross_attention
144
- else None,
145
  attention_mask=attention_mask,
146
  )
147
 
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
 
39
+ from typing import Optional
40
 
41
  import torch
42
  import torch.nn.functional as F
 
56
  dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57
  cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58
  activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
 
 
59
  attention_bias (:
60
  obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
61
  only_cross_attention (`bool`, *optional*):
 
70
  The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71
  final_dropout (`bool` *optional*, defaults to False):
72
  Whether to apply a final dropout after the last feed-forward layer.
 
 
73
  """
74
 
75
  def __init__(
 
115
 
116
  self.attn2 = Attention(
117
  query_dim=dim,
118
+ cross_attention_dim=(
119
+ cross_attention_dim if not double_self_attention else None
120
+ ),
121
  heads=num_attention_heads,
122
  dim_head=attention_head_dim,
123
  dropout=dropout,
 
159
 
160
  attn_output = self.attn1(
161
  norm_hidden_states,
162
+ encoder_hidden_states=(
163
+ encoder_hidden_states if self.only_cross_attention else None
164
+ ),
165
  attention_mask=attention_mask,
166
  )
167
 
tsr/models/transformer/transformer_1d.py CHANGED
@@ -1,5 +1,43 @@
1
- from dataclasses import dataclass, field
2
- from typing import Any, Dict, Optional
3
 
4
  import torch
5
  import torch.nn.functional as F
@@ -10,28 +48,6 @@ from .basic_transformer_block import BasicTransformerBlock
10
 
11
 
12
  class Transformer1D(BaseModule):
13
- """
14
- A 1D Transformer model for sequence data.
15
-
16
- Parameters:
17
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
18
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
19
- in_channels (`int`, *optional*):
20
- The number of channels in the input and output (specify if the input is **continuous**).
21
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
22
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
23
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
24
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
25
- num_embeds_ada_norm ( `int`, *optional*):
26
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
27
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
28
- added to the hidden states.
29
-
30
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
31
- attention_bias (`bool`, *optional*):
32
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
33
- """
34
-
35
  @dataclass
36
  class Config(BaseModule.Config):
37
  num_attention_heads: int = 16
@@ -119,15 +135,6 @@ class Transformer1D(BaseModule):
119
  encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
120
  Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
121
  self-attention.
122
- timestep ( `torch.LongTensor`, *optional*):
123
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
124
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
125
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
126
- `AdaLayerZeroNorm`.
127
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
128
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
129
- `self.processor` in
130
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
131
  attention_mask ( `torch.Tensor`, *optional*):
132
  An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
133
  is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
@@ -140,13 +147,9 @@ class Transformer1D(BaseModule):
140
 
141
  If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
142
  above. This bias will be added to the cross-attention scores.
143
- return_dict (`bool`, *optional*, defaults to `True`):
144
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
145
- tuple.
146
 
147
  Returns:
148
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
149
- `tuple` where the first element is the sample tensor.
150
  """
151
  # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
152
  # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from dataclasses import dataclass
40
+ from typing import Optional
41
 
42
  import torch
43
  import torch.nn.functional as F
 
48
 
49
 
50
  class Transformer1D(BaseModule):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @dataclass
52
  class Config(BaseModule.Config):
53
  num_attention_heads: int = 16
 
135
  encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
136
  Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137
  self-attention.
138
  attention_mask ( `torch.Tensor`, *optional*):
139
  An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
140
  is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
 
147
 
148
  If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
149
  above. This bias will be added to the cross-attention scores.
 
 
 
150
 
151
  Returns:
152
+ torch.FloatTensor
 
153
  """
154
  # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
155
  # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
tsr/system.py CHANGED
@@ -13,7 +13,6 @@ from huggingface_hub import hf_hub_download
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
 
16
- from .models.isosurface import MarchingCubeHelper
17
  from .utils import (
18
  BaseModule,
19
  ImagePreprocessor,
@@ -50,17 +49,17 @@ class TSR(BaseModule):
50
 
51
  @classmethod
52
  def from_pretrained(
53
- cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None
54
  ):
55
  if os.path.isdir(pretrained_model_name_or_path):
56
  config_path = os.path.join(pretrained_model_name_or_path, config_name)
57
  weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
58
  else:
59
  config_path = hf_hub_download(
60
- repo_id=pretrained_model_name_or_path, filename=config_name, token=token
61
  )
62
  weight_path = hf_hub_download(
63
- repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
64
  )
65
 
66
  cfg = OmegaConf.load(config_path)
@@ -160,36 +159,20 @@ class TSR(BaseModule):
160
 
161
  return images
162
 
163
- def set_marching_cubes_resolution(self, resolution: int):
164
- if (
165
- self.isosurface_helper is not None
166
- and self.isosurface_helper.resolution == resolution
167
- ):
168
- return
169
- self.isosurface_helper = MarchingCubeHelper(resolution)
170
-
171
  def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
172
- self.set_marching_cubes_resolution(resolution)
173
  meshes = []
174
  for scene_code in scene_codes:
175
  with torch.no_grad():
176
- density = self.renderer.query_triplane_volume_density(
177
- self.decoder.to(scene_codes.device),
178
- scene_code,
179
- resolution
180
- ) - threshold
181
- v_pos, t_pos_idx = self.isosurface_helper(density)
182
- density = None
183
- v_pos = v_pos.to(scene_codes.device)
184
- color = self.renderer.query_triplane(
185
- self.decoder.to(scene_codes.device),
186
- v_pos,
187
  scene_code,
188
- )["color"]
 
 
 
189
  v_pos = scale_tensor(
190
  v_pos,
191
- self.isosurface_helper.points_range,
192
- (-self.renderer.cfg.radius, self.renderer.cfg.radius),
193
  )
194
  mesh = trimesh.Trimesh(
195
  vertices=v_pos.cpu().numpy(),
 
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
 
 
16
  from .utils import (
17
  BaseModule,
18
  ImagePreprocessor,
 
49
 
50
  @classmethod
51
  def from_pretrained(
52
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
53
  ):
54
  if os.path.isdir(pretrained_model_name_or_path):
55
  config_path = os.path.join(pretrained_model_name_or_path, config_name)
56
  weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
57
  else:
58
  config_path = hf_hub_download(
59
+ repo_id=pretrained_model_name_or_path, filename=config_name
60
  )
61
  weight_path = hf_hub_download(
62
+ repo_id=pretrained_model_name_or_path, filename=weight_name
63
  )
64
 
65
  cfg = OmegaConf.load(config_path)
 
159
 
160
  return images
161
 
 
 
 
 
 
 
 
 
162
  def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
 
163
  meshes = []
164
  for scene_code in scene_codes:
165
  with torch.no_grad():
166
+ v_pos, t_pos_idx = self.renderer.block_based_marchingcube(self.decoder.to(scene_codes.device),
167
  scene_code,
168
+ resolution,
169
+ threshold
170
+ )
171
+ color = self.renderer.query_triplane(self.decoder.to(scene_codes.device), v_pos.to(scene_codes.device), scene_code, False)["color"]
172
  v_pos = scale_tensor(
173
  v_pos,
174
+ (-1.0, 1.0),
175
+ (-self.renderer.cfg.radius, self.renderer.cfg.radius)
176
  )
177
  mesh = trimesh.Trimesh(
178
  vertices=v_pos.cpu().numpy(),
tsr/utils.py CHANGED
@@ -300,7 +300,6 @@ def get_rays(
300
  directions,
301
  c2w,
302
  keepdim=False,
303
- noise_scale=0.0,
304
  normalize=False,
305
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
306
  # Rotate ray directions from camera coordinate to the world coordinate
@@ -331,12 +330,6 @@ def get_rays(
331
  ) # (B, H, W, 3)
332
  rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
333
 
334
- # add camera noise to avoid grid-like artifect
335
- # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
336
- if noise_scale > 0:
337
- rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
338
- rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
339
-
340
  if normalize:
341
  rays_d = F.normalize(rays_d, dim=-1)
342
  if not keepdim:
@@ -477,6 +470,5 @@ def save_video(
477
 
478
  def to_gradio_3d_orientation(mesh):
479
  mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
480
- mesh.apply_scale([1, 1, -1])
481
  mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
482
  return mesh
 
300
  directions,
301
  c2w,
302
  keepdim=False,
 
303
  normalize=False,
304
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
305
  # Rotate ray directions from camera coordinate to the world coordinate
 
330
  ) # (B, H, W, 3)
331
  rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
332
 
 
 
 
 
 
 
333
  if normalize:
334
  rays_d = F.normalize(rays_d, dim=-1)
335
  if not keepdim:
 
470
 
471
  def to_gradio_3d_orientation(mesh):
472
  mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
 
473
  mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
474
  return mesh
upload.py ADDED
@@ -0,0 +1,8 @@
1
+ from huggingface_hub import HfApi
2
+ api = HfApi()
3
+
4
+ api.upload_folder(
5
+ folder_path="/workspaces/TripoSR",
6
+ repo_id="michaelj/TripoSR",
7
+ repo_type="space",
8
+ )
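
A note on upload.py: HfApi.upload_folder needs write access to the target repo, so in practice credentials come from a prior huggingface-cli login or are passed when constructing the client. A minimal sketch of the same call with an explicit (placeholder) token and an optional commit message:

from huggingface_hub import HfApi

# "hf_xxx" is a placeholder; a real token would normally come from
# `huggingface-cli login` or an environment variable, not be hard-coded.
api = HfApi(token="hf_xxx")
api.upload_folder(
    folder_path="/workspaces/TripoSR",
    repo_id="michaelj/TripoSR",
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)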