02alexander commited on
Commit
71d5bf5
·
1 Parent(s): 344c16f

copy code to this repo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitignore CHANGED
@@ -20,3 +20,5 @@ __pycache__
20
  .mypy_cache
21
  .ruff_cache
22
  venv
 
 
 
20
  .mypy_cache
21
  .ruff_cache
22
  venv
23
+
24
+ shell.nix
CMakeLists.txt DELETED
@@ -1,18 +0,0 @@
1
- cmake_minimum_required(VERSION 3.16...3.27)
2
-
3
- project(PROJ_NAME LANGUAGES CXX)
4
-
5
- set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
6
-
7
- if(NOT DEFINED CMAKE_CXX_STANDARD)
8
- set(CMAKE_CXX_STANDARD 17)
9
- endif()
10
-
11
- # Rerun:
12
- include(FetchContent)
13
- FetchContent_Declare(rerun_sdk URL https://github.com/rerun-io/rerun/releases/download/0.15.1/rerun_cpp_sdk.zip)
14
- FetchContent_MakeAvailable(rerun_sdk)
15
-
16
- add_executable(PROJ_NAME src/main.cpp)
17
- target_link_libraries(PROJ_NAME rerun_sdk)
18
- target_include_directories(PROJ_NAME PRIVATE src)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Cargo.lock DELETED
@@ -1,7 +0,0 @@
1
- # This file is automatically @generated by Cargo.
2
- # It is not intended for manual editing.
3
- version = 3
4
-
5
- [[package]]
6
- name = "new_project_name"
7
- version = "0.1.0"
 
 
 
 
 
 
 
 
Cargo.toml DELETED
@@ -1,198 +0,0 @@
1
- [package]
2
- authors = ["rerun.io <opensource@rerun.io>"]
3
- categories = [] # TODO: fill in if you plan on publishing the crate
4
- description = "" # TODO: fill in if you plan on publishing the crate
5
- edition = "2021"
6
- homepage = "https://github.com/rerun-io/new_repo_name"
7
- include = ["LICENSE-APACHE", "LICENSE-MIT", "**/*.rs", "Cargo.toml"]
8
- keywords = [] # TODO: fill in if you plan on publishing the crate
9
- license = "MIT OR Apache-2.0"
10
- name = "new_project_name"
11
- publish = false # TODO: set to `true` if you plan on publishing the crate
12
- readme = "README.md"
13
- repository = "https://github.com/rerun-io/new_repo_name"
14
- rust-version = "1.76"
15
- version = "0.1.0"
16
-
17
- [package.metadata.docs.rs]
18
- all-features = true
19
- targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
20
-
21
-
22
- [features]
23
- default = []
24
-
25
-
26
- [dependencies]
27
-
28
-
29
- [dev-dependencies]
30
-
31
-
32
- [patch.crates-io]
33
-
34
-
35
- [lints]
36
- workspace = true
37
-
38
-
39
- [workspace.lints.rust]
40
- unsafe_code = "deny"
41
-
42
- elided_lifetimes_in_paths = "warn"
43
- future_incompatible = "warn"
44
- nonstandard_style = "warn"
45
- rust_2018_idioms = "warn"
46
- rust_2021_prelude_collisions = "warn"
47
- semicolon_in_expressions_from_macros = "warn"
48
- trivial_numeric_casts = "warn"
49
- unsafe_op_in_unsafe_fn = "warn" # `unsafe_op_in_unsafe_fn` may become the default in future Rust versions: https://github.com/rust-lang/rust/issues/71668
50
- unused_extern_crates = "warn"
51
- unused_import_braces = "warn"
52
- unused_lifetimes = "warn"
53
-
54
- trivial_casts = "allow"
55
- unused_qualifications = "allow"
56
-
57
- [workspace.lints.rustdoc]
58
- all = "warn"
59
- missing_crate_level_docs = "warn"
60
-
61
- # See also clippy.toml
62
- [workspace.lints.clippy]
63
- as_ptr_cast_mut = "warn"
64
- await_holding_lock = "warn"
65
- bool_to_int_with_if = "warn"
66
- char_lit_as_u8 = "warn"
67
- checked_conversions = "warn"
68
- clear_with_drain = "warn"
69
- cloned_instead_of_copied = "warn"
70
- dbg_macro = "warn"
71
- debug_assert_with_mut_call = "warn"
72
- derive_partial_eq_without_eq = "warn"
73
- disallowed_macros = "warn" # See clippy.toml
74
- disallowed_methods = "warn" # See clippy.toml
75
- disallowed_names = "warn" # See clippy.toml
76
- disallowed_script_idents = "warn" # See clippy.toml
77
- disallowed_types = "warn" # See clippy.toml
78
- doc_link_with_quotes = "warn"
79
- doc_markdown = "warn"
80
- empty_enum = "warn"
81
- enum_glob_use = "warn"
82
- equatable_if_let = "warn"
83
- exit = "warn"
84
- expl_impl_clone_on_copy = "warn"
85
- explicit_deref_methods = "warn"
86
- explicit_into_iter_loop = "warn"
87
- explicit_iter_loop = "warn"
88
- fallible_impl_from = "warn"
89
- filter_map_next = "warn"
90
- flat_map_option = "warn"
91
- float_cmp_const = "warn"
92
- fn_params_excessive_bools = "warn"
93
- fn_to_numeric_cast_any = "warn"
94
- from_iter_instead_of_collect = "warn"
95
- get_unwrap = "warn"
96
- if_let_mutex = "warn"
97
- implicit_clone = "warn"
98
- imprecise_flops = "warn"
99
- index_refutable_slice = "warn"
100
- inefficient_to_string = "warn"
101
- infinite_loop = "warn"
102
- into_iter_without_iter = "warn"
103
- invalid_upcast_comparisons = "warn"
104
- iter_not_returning_iterator = "warn"
105
- iter_on_empty_collections = "warn"
106
- iter_on_single_items = "warn"
107
- iter_over_hash_type = "warn"
108
- iter_without_into_iter = "warn"
109
- large_digit_groups = "warn"
110
- large_include_file = "warn"
111
- large_stack_arrays = "warn"
112
- large_stack_frames = "warn"
113
- large_types_passed_by_value = "warn"
114
- let_underscore_untyped = "warn"
115
- let_unit_value = "warn"
116
- linkedlist = "warn"
117
- lossy_float_literal = "warn"
118
- macro_use_imports = "warn"
119
- manual_assert = "warn"
120
- manual_clamp = "warn"
121
- manual_instant_elapsed = "warn"
122
- manual_let_else = "warn"
123
- manual_ok_or = "warn"
124
- manual_string_new = "warn"
125
- map_err_ignore = "warn"
126
- map_flatten = "warn"
127
- map_unwrap_or = "warn"
128
- match_on_vec_items = "warn"
129
- match_same_arms = "warn"
130
- match_wild_err_arm = "warn"
131
- match_wildcard_for_single_variants = "warn"
132
- mem_forget = "warn"
133
- mismatched_target_os = "warn"
134
- mismatching_type_param_order = "warn"
135
- missing_assert_message = "warn"
136
- missing_enforced_import_renames = "warn"
137
- missing_errors_doc = "warn"
138
- missing_safety_doc = "warn"
139
- mut_mut = "warn"
140
- mutex_integer = "warn"
141
- needless_borrow = "warn"
142
- needless_continue = "warn"
143
- needless_for_each = "warn"
144
- needless_pass_by_ref_mut = "warn"
145
- needless_pass_by_value = "warn"
146
- negative_feature_names = "warn"
147
- nonstandard_macro_braces = "warn"
148
- option_option = "warn"
149
- path_buf_push_overwrite = "warn"
150
- ptr_as_ptr = "warn"
151
- ptr_cast_constness = "warn"
152
- pub_without_shorthand = "warn"
153
- rc_mutex = "warn"
154
- readonly_write_lock = "warn"
155
- redundant_type_annotations = "warn"
156
- ref_option_ref = "warn"
157
- rest_pat_in_fully_bound_structs = "warn"
158
- same_functions_in_if_condition = "warn"
159
- semicolon_if_nothing_returned = "warn"
160
- should_panic_without_expect = "warn"
161
- significant_drop_tightening = "warn"
162
- single_match_else = "warn"
163
- str_to_string = "warn"
164
- string_add = "warn"
165
- string_add_assign = "warn"
166
- string_lit_as_bytes = "warn"
167
- string_lit_chars_any = "warn"
168
- string_to_string = "warn"
169
- suspicious_command_arg_space = "warn"
170
- suspicious_xor_used_as_pow = "warn"
171
- todo = "warn"
172
- too_many_lines = "warn"
173
- trailing_empty_array = "warn"
174
- trait_duplication_in_bounds = "warn"
175
- tuple_array_conversions = "warn"
176
- unchecked_duration_subtraction = "warn"
177
- undocumented_unsafe_blocks = "warn"
178
- unimplemented = "warn"
179
- uninhabited_references = "warn"
180
- uninlined_format_args = "warn"
181
- unnecessary_box_returns = "warn"
182
- unnecessary_safety_doc = "warn"
183
- unnecessary_struct_initialization = "warn"
184
- unnecessary_wraps = "warn"
185
- unnested_or_patterns = "warn"
186
- unused_peekable = "warn"
187
- unused_rounding = "warn"
188
- unused_self = "warn"
189
- unwrap_used = "warn"
190
- use_self = "warn"
191
- useless_transmute = "warn"
192
- verbose_file_reads = "warn"
193
- wildcard_dependencies = "warn"
194
- wildcard_imports = "warn"
195
- zero_sized_map_values = "warn"
196
-
197
- manual_range_contains = "allow" # this one is just worse imho
198
- ref_patterns = "allow" # It's nice to avoid ref pattern, but there are some situations that are hard (impossible?) to express without.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,40 +1,3 @@
1
- # Rerun template repository
2
- Template for our private and public repos, containing CI, CoC, etc
3
 
4
- When creating a new Rerun repository, use this as a template, then modify it as it makes sense.
5
-
6
- This template should be the default for any repository of any kind, including:
7
- * Rust projects
8
- * C++ projects
9
- * Python projects
10
- * Other stuff
11
-
12
- This template includes
13
- * License files
14
- * Code of Conduct
15
- * Helpers for checking and linting Rust code
16
- - `cargo-clippy`
17
- - `cargo-deny`
18
- - `rust-toolchain`
19
- - …
20
- * CI for:
21
- - Spell checking
22
- - Link checking
23
- - C++ checks
24
- - Python checks
25
- - Rust checks
26
-
27
-
28
- ## How to use
29
- Start by clicking "Use this template" at https://github.com/rerun-io/rerun_template/ or follow [these instructions](https://docs.github.com/en/free-pro-team@latest/github/creating-cloning-and-archiving-repositories/creating-a-repository-from-a-template).
30
-
31
- Then follow these steps:
32
- * Run `scripts/template_update.py init --languages cpp,rust,python` to delete files you don't need (give the languages you need support for)
33
- * Search and replace all instances of `new_repo_name` with the name of the repository.
34
- * Search and replace all instances of `new_project_name` with the name of the project (crate/binary name).
35
- * Search for `TODO` and fill in all those places
36
- * Replace this `README.md` with something better
37
- * Commit!
38
-
39
- In the future you can always update this repository with the latest changes from the template by running:
40
- * `scripts/template_update.py update --languages cpp,rust,python`
 
1
+ ## Fork of the [InstantMesh space]() but with [Rerun](https://www.rerun.io) for visualization
 
2
 
3
+ The resulting Huggingface space can be found [here.](https://huggingface.co/spaces/rerun/InstantMesh)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ import threading
6
+ from queue import SimpleQueue
7
+ from typing import Any
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import rembg
12
+ import rerun as rr
13
+ import rerun.blueprint as rrb
14
+ import spaces
15
+ import torch
16
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
17
+ from einops import rearrange
18
+ from gradio_rerun import Rerun
19
+ from huggingface_hub import hf_hub_download
20
+ from omegaconf import OmegaConf
21
+ from PIL import Image
22
+ from pytorch_lightning import seed_everything
23
+ from torchvision.transforms import v2
24
+
25
+ from src.models.lrm_mesh import InstantMesh
26
+ from src.utils.camera_util import (
27
+ FOV_to_intrinsics,
28
+ get_circular_camera_poses,
29
+ get_zero123plus_input_cameras,
30
+ )
31
+ from src.utils.infer_util import remove_background, resize_foreground
32
+ from src.utils.train_util import instantiate_from_config
33
+
34
+
35
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
36
+ """Get the rendering camera parameters."""
37
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
38
+ if is_flexicubes:
39
+ cameras = torch.linalg.inv(c2ws)
40
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
41
+ else:
42
+ extrinsics = c2ws.flatten(-2)
43
+ intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
44
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
45
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
46
+ return cameras
47
+
48
+
49
+ ###############################################################################
50
+ # Configuration.
51
+ ###############################################################################
52
+
53
+
54
+ def find_cuda():
55
+ # Check if CUDA_HOME or CUDA_PATH environment variables are set
56
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
57
+
58
+ if cuda_home and os.path.exists(cuda_home):
59
+ return cuda_home
60
+
61
+ # Search for the nvcc executable in the system's PATH
62
+ nvcc_path = shutil.which("nvcc")
63
+
64
+ if nvcc_path:
65
+ # Remove the 'bin/nvcc' part to get the CUDA installation path
66
+ cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
67
+ return cuda_path
68
+
69
+ return None
70
+
71
+
72
+ cuda_path = find_cuda()
73
+
74
+ if cuda_path:
75
+ print(f"CUDA installation found at: {cuda_path}")
76
+ else:
77
+ print("CUDA installation not found")
78
+
79
+ config_path = "configs/instant-mesh-large.yaml"
80
+ config = OmegaConf.load(config_path)
81
+ config_name = os.path.basename(config_path).replace(".yaml", "")
82
+ model_config = config.model_config
83
+ infer_config = config.infer_config
84
+
85
+ IS_FLEXICUBES = True if config_name.startswith("instant-mesh") else False
86
+
87
+ device = torch.device("cuda")
88
+
89
+ # load diffusion model
90
+ print("Loading diffusion model ...")
91
+ pipeline = DiffusionPipeline.from_pretrained(
92
+ "sudo-ai/zero123plus-v1.2",
93
+ custom_pipeline="zero123plus",
94
+ torch_dtype=torch.float16,
95
+ )
96
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
97
+
98
+ # load custom white-background UNet
99
+ unet_ckpt_path = hf_hub_download(
100
+ repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model"
101
+ )
102
+ state_dict = torch.load(unet_ckpt_path, map_location="cpu")
103
+ pipeline.unet.load_state_dict(state_dict, strict=True)
104
+
105
+ pipeline = pipeline.to(device)
106
+ print(f"type(pipeline)={type(pipeline)}")
107
+
108
+ # load reconstruction model
109
+ print("Loading reconstruction model ...")
110
+ model_ckpt_path = hf_hub_download(
111
+ repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model"
112
+ )
113
+ model: InstantMesh = instantiate_from_config(model_config)
114
+ state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
115
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith("lrm_generator.") and "source_camera" not in k}
116
+ model.load_state_dict(state_dict, strict=True)
117
+
118
+ model = model.to(device)
119
+
120
+ print("Loading Finished!")
121
+
122
+
123
+ def check_input_image(input_image):
124
+ if input_image is None:
125
+ raise gr.Error("No image uploaded!")
126
+
127
+
128
+ def preprocess(input_image, do_remove_background):
129
+ rembg_session = rembg.new_session() if do_remove_background else None
130
+
131
+ if do_remove_background:
132
+ input_image = remove_background(input_image, rembg_session)
133
+ input_image = resize_foreground(input_image, 0.85)
134
+
135
+ return input_image
136
+
137
+
138
+ def pipeline_callback(
139
+ log_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]
140
+ ) -> dict[str, Any]:
141
+ latents = callback_kwargs["latents"]
142
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] # type: ignore[attr-defined]
143
+ image = pipe.image_processor.postprocess(image, output_type="np").squeeze() # type: ignore[attr-defined]
144
+
145
+ log_queue.put(("mvs", rr.Image(image)))
146
+ log_queue.put(("latents", rr.Tensor(latents.squeeze())))
147
+
148
+ return callback_kwargs
149
+
150
+
151
+ def generate_mvs(log_queue, input_image, sample_steps, sample_seed):
152
+ seed_everything(sample_seed)
153
+
154
+ return pipeline(
155
+ input_image,
156
+ num_inference_steps=sample_steps,
157
+ callback_on_step_end=lambda *args, **kwargs: pipeline_callback(log_queue, *args, **kwargs),
158
+ ).images[0]
159
+
160
+
161
+ def make3d(log_queue, images: Image.Image):
162
+ global model
163
+ if IS_FLEXICUBES:
164
+ model.init_flexicubes_geometry(device, use_renderer=False)
165
+ model = model.eval()
166
+
167
+ images = np.asarray(images, dtype=np.float32) / 255.0
168
+ images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
169
+ images = rearrange(images, "c (n h) (m w) -> (n m) c h w", n=3, m=2) # (6, 3, 320, 320)
170
+
171
+ input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
172
+
173
+ images = images.unsqueeze(0).to(device)
174
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
175
+
176
+ with torch.no_grad():
177
+ # get triplane
178
+ planes = model.forward_planes(images, input_cameras)
179
+
180
+ # get mesh
181
+ mesh_out = model.extract_mesh(
182
+ planes,
183
+ use_texture_map=False,
184
+ **infer_config,
185
+ )
186
+
187
+ vertices, faces, vertex_colors = mesh_out
188
+
189
+ log_queue.put((
190
+ "mesh",
191
+ rr.Mesh3D(vertex_positions=vertices, vertex_colors=vertex_colors, triangle_indices=faces),
192
+ ))
193
+
194
+ return mesh_out
195
+
196
+
197
+ def generate_blueprint() -> rrb.Blueprint:
198
+ return rrb.Blueprint(
199
+ rrb.Horizontal(
200
+ rrb.Spatial3DView(origin="mesh"),
201
+ rrb.Grid(
202
+ rrb.Spatial2DView(origin="z123image"),
203
+ rrb.Spatial2DView(origin="preprocessed_image"),
204
+ rrb.Spatial2DView(origin="mvs"),
205
+ rrb.TensorView(
206
+ origin="latents",
207
+ ),
208
+ ),
209
+ column_shares=[1, 1],
210
+ ),
211
+ collapse_panels=True,
212
+ )
213
+
214
+
215
+ def compute(log_queue, input_image, do_remove_background, sample_steps, sample_seed):
216
+ preprocessed_image = preprocess(input_image, do_remove_background)
217
+ log_queue.put(("preprocessed_image", rr.Image(preprocessed_image)))
218
+
219
+ z123_image = generate_mvs(log_queue, preprocessed_image, sample_steps, sample_seed)
220
+ log_queue.put(("z123image", rr.Image(z123_image)))
221
+
222
+ _mesh_out = make3d(log_queue, z123_image)
223
+
224
+ log_queue.put("done")
225
+
226
+
227
+ @spaces.GPU
228
+ @rr.thread_local_stream("InstantMesh")
229
+ def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
230
+ log_queue = SimpleQueue()
231
+
232
+ stream = rr.binary_stream()
233
+
234
+ blueprint = generate_blueprint()
235
+ rr.send_blueprint(blueprint)
236
+ yield stream.read()
237
+
238
+ handle = threading.Thread(
239
+ target=compute, args=[log_queue, input_image, do_remove_background, sample_steps, sample_seed]
240
+ )
241
+ handle.start()
242
+ while True:
243
+ msg = log_queue.get()
244
+ if msg == "done":
245
+ break
246
+ else:
247
+ entity_path, entity = msg
248
+ rr.log(entity_path, entity)
249
+ yield stream.read()
250
+ handle.join()
251
+
252
+
253
+ _HEADER_ = """
254
+ <h2><b>Duplicate of the <a href='https://huggingface.co/spaces/TencentARC/InstantMesh'>InstantMesh space</a> that uses <a href='https://rerun.io/'>Rerun</a> for visualization.</b></h2>
255
+ <h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
256
+
257
+ **InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
258
+
259
+ Technical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
260
+ Source code: <a href='https://github.com/rerun-io/hf-example-instant-mesh'>Github</a>.
261
+ """
262
+
263
+ with gr.Blocks() as demo:
264
+ gr.Markdown(_HEADER_)
265
+ with gr.Row(variant="panel"):
266
+ with gr.Column(scale=1):
267
+ with gr.Row():
268
+ input_image = gr.Image(
269
+ label="Input Image",
270
+ image_mode="RGBA",
271
+ sources="upload",
272
+ # width=256,
273
+ # height=256,
274
+ type="pil",
275
+ elem_id="content_image",
276
+ )
277
+ with gr.Row():
278
+ with gr.Group():
279
+ do_remove_background = gr.Checkbox(label="Remove Background", value=True)
280
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
281
+
282
+ sample_steps = gr.Slider(label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
283
+
284
+ with gr.Row():
285
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
286
+
287
+ with gr.Row(variant="panel"):
288
+ gr.Examples(
289
+ examples=[os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))],
290
+ inputs=[input_image],
291
+ label="Examples",
292
+ cache_examples=False,
293
+ examples_per_page=16,
294
+ )
295
+
296
+ with gr.Column(scale=2):
297
+ viewer = Rerun(streaming=True, height=800)
298
+
299
+ with gr.Row():
300
+ gr.Markdown("""Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).""")
301
+
302
+ mv_images = gr.State()
303
+
304
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
305
+ fn=log_to_rr, inputs=[input_image, do_remove_background, sample_steps, sample_seed], outputs=[viewer]
306
+ )
307
+
308
+ demo.launch()
configs/instant-mesh-base.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_base.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-mesh-large.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm_mesh.InstantMesh
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/instant_mesh_large.ckpt
21
+ texture_resolution: 1024
22
+ render_resolution: 512
configs/instant-nerf-base.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 12
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 40
13
+ rendering_samples_per_ray: 96
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_base.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
configs/instant-nerf-large.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ target: src.models.lrm.InstantNeRF
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+
15
+
16
+ infer_config:
17
+ unet_path: ckpts/diffusion_pytorch_model.bin
18
+ model_path: ckpts/instant_nerf_large.ckpt
19
+ mesh_threshold: 10.0
20
+ mesh_resolution: 256
21
+ render_resolution: 384
examples/bird.jpg ADDED
examples/bubble_mart_blue.png ADDED
examples/cake.jpg ADDED
examples/cartoon_dinosaur.png ADDED
examples/chair_armed.png ADDED
examples/chair_comfort.jpg ADDED
examples/chair_wood.jpg ADDED
examples/chest.jpg ADDED
examples/cute_horse.jpg ADDED
examples/cute_tiger.jpg ADDED
examples/earphone.jpg ADDED
examples/fox.jpg ADDED
examples/fruit.jpg ADDED
examples/fruit_elephant.jpg ADDED
examples/genshin_building.png ADDED
examples/genshin_teapot.png ADDED
examples/hatsune_miku.png ADDED
examples/house2.jpg ADDED
examples/mushroom_teapot.jpg ADDED
examples/pikachu.png ADDED
examples/plant.jpg ADDED
examples/robot.jpg ADDED
examples/sea_turtle.png ADDED
examples/skating_shoe.jpg ADDED
examples/sorting_board.png ADDED
examples/sword.png ADDED
examples/toy_car.jpg ADDED
examples/watermelon.png ADDED
examples/whitedog.png ADDED
examples/x_teapot.jpg ADDED
examples/x_toyduck.jpg ADDED
main.py DELETED
@@ -1,11 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- from __future__ import annotations
4
-
5
-
6
- def main() -> None:
7
- pass
8
-
9
-
10
- if __name__ == "__main__":
11
- main()
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1 +1,27 @@
1
- rerun-sdk>=0.15.0,<0.16.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ torch==2.1.0
3
+ torchvision==0.16.0
4
+ torchaudio==2.1.0
5
+ pytorch-lightning==2.1.2
6
+ einops
7
+ omegaconf
8
+ deepspeed
9
+ torchmetrics
10
+ webdataset
11
+ accelerate
12
+ tensorboard
13
+ PyMCubes
14
+ trimesh
15
+ rembg
16
+ transformers
17
+ diffusers==0.28.2
18
+ bitsandbytes
19
+ imageio[ffmpeg]
20
+ xatlas
21
+ plyfile
22
+ xformers==0.0.22.post7
23
+ git+https://github.com/NVlabs/nvdiffrast/
24
+ huggingface-hub
25
+ gradio_client >= 0.12
26
+ rerun-sdk>=0.16.0,<0.17.0
27
+ gradio_rerun
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/objaverse.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import webdataset as wds
14
+ from PIL import Image
15
+ from torch.utils.data import Dataset
16
+ from torch.utils.data.distributed import DistributedSampler
17
+
18
+ from src.utils.camera_util import (
19
+ FOV_to_intrinsics,
20
+ center_looking_at_camera_pose,
21
+ get_surrounding_views,
22
+ )
23
+ from src.utils.train_util import instantiate_from_config
24
+
25
+
26
+ class DataModuleFromConfig(pl.LightningDataModule):
27
+ def __init__(
28
+ self,
29
+ batch_size=8,
30
+ num_workers=4,
31
+ train=None,
32
+ validation=None,
33
+ test=None,
34
+ **kwargs,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.batch_size = batch_size
39
+ self.num_workers = num_workers
40
+
41
+ self.dataset_configs = dict()
42
+ if train is not None:
43
+ self.dataset_configs['train'] = train
44
+ if validation is not None:
45
+ self.dataset_configs['validation'] = validation
46
+ if test is not None:
47
+ self.dataset_configs['test'] = test
48
+
49
+ def setup(self, stage):
50
+
51
+ if stage in ['fit']:
52
+ self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
53
+ else:
54
+ raise NotImplementedError
55
+
56
+ def train_dataloader(self):
57
+
58
+ sampler = DistributedSampler(self.datasets['train'])
59
+ return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
60
+
61
+ def val_dataloader(self):
62
+
63
+ sampler = DistributedSampler(self.datasets['validation'])
64
+ return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
65
+
66
+ def test_dataloader(self):
67
+
68
+ return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
69
+
70
+
71
+ class ObjaverseData(Dataset):
72
+ def __init__(self,
73
+ root_dir='objaverse/',
74
+ meta_fname='valid_paths.json',
75
+ input_image_dir='rendering_random_32views',
76
+ target_image_dir='rendering_random_32views',
77
+ input_view_num=6,
78
+ target_view_num=2,
79
+ total_view_n=32,
80
+ fov=50,
81
+ camera_rotation=True,
82
+ validation=False,
83
+ ):
84
+ self.root_dir = Path(root_dir)
85
+ self.input_image_dir = input_image_dir
86
+ self.target_image_dir = target_image_dir
87
+
88
+ self.input_view_num = input_view_num
89
+ self.target_view_num = target_view_num
90
+ self.total_view_n = total_view_n
91
+ self.fov = fov
92
+ self.camera_rotation = camera_rotation
93
+
94
+ with open(os.path.join(root_dir, meta_fname)) as f:
95
+ filtered_dict = json.load(f)
96
+ paths = filtered_dict['good_objs']
97
+ self.paths = paths
98
+
99
+ self.depth_scale = 4.0
100
+
101
+ len(self.paths)
102
+ print('============= length of dataset %d =============' % len(self.paths))
103
+
104
+ def __len__(self):
105
+ return len(self.paths)
106
+
107
+ def load_im(self, path, color):
108
+ """Replace background pixel with random color in rendering."""
109
+ pil_img = Image.open(path)
110
+
111
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
112
+ alpha = image[:, :, 3:]
113
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
114
+
115
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
116
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
117
+ return image, alpha
118
+
119
+ def __getitem__(self, index):
120
+ # load data
121
+ while True:
122
+ input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
123
+ target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
124
+
125
+ indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
126
+ input_indices = indices[:self.input_view_num]
127
+ target_indices = indices[self.input_view_num:]
128
+
129
+ '''background color, default: white'''
130
+ bg_white = [1., 1., 1.]
131
+ bg_black = [0., 0., 0.]
132
+
133
+ image_list = []
134
+ alpha_list = []
135
+ depth_list = []
136
+ normal_list = []
137
+ pose_list = []
138
+
139
+ try:
140
+ input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
141
+ for idx in input_indices:
142
+ image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
143
+ normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
144
+ depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
145
+ depth = torch.from_numpy(depth).unsqueeze(0)
146
+ pose = input_cameras[idx]
147
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
148
+
149
+ image_list.append(image)
150
+ alpha_list.append(alpha)
151
+ depth_list.append(depth)
152
+ normal_list.append(normal)
153
+ pose_list.append(pose)
154
+
155
+ target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
156
+ for idx in target_indices:
157
+ image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
158
+ normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
159
+ depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
160
+ depth = torch.from_numpy(depth).unsqueeze(0)
161
+ pose = target_cameras[idx]
162
+ pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
163
+
164
+ image_list.append(image)
165
+ alpha_list.append(alpha)
166
+ depth_list.append(depth)
167
+ normal_list.append(normal)
168
+ pose_list.append(pose)
169
+
170
+ except Exception as e:
171
+ print(e)
172
+ index = np.random.randint(0, len(self.paths))
173
+ continue
174
+
175
+ break
176
+
177
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
178
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
179
+ depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
180
+ normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
181
+ w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
182
+ c2ws = torch.linalg.inv(w2cs).float()
183
+
184
+ normals = normals * 2.0 - 1.0
185
+ normals = F.normalize(normals, dim=1)
186
+ normals = (normals + 1.0) / 2.0
187
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
188
+
189
+ # random rotation along z axis
190
+ if self.camera_rotation:
191
+ degree = np.random.uniform(0, math.pi * 2)
192
+ rot = torch.tensor([
193
+ [np.cos(degree), -np.sin(degree), 0, 0],
194
+ [np.sin(degree), np.cos(degree), 0, 0],
195
+ [0, 0, 1, 0],
196
+ [0, 0, 0, 1],
197
+ ]).unsqueeze(0).float()
198
+ c2ws = torch.matmul(rot, c2ws)
199
+
200
+ # rotate normals
201
+ N, _, H, W = normals.shape
202
+ normals = normals * 2.0 - 1.0
203
+ normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
204
+ normals = F.normalize(normals, dim=1)
205
+ normals = (normals + 1.0) / 2.0
206
+ normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
207
+
208
+ # random scaling
209
+ if np.random.rand() < 0.5:
210
+ scale = np.random.uniform(0.8, 1.0)
211
+ c2ws[:, :3, 3] *= scale
212
+ depths *= scale
213
+
214
+ # instrinsics of perspective cameras
215
+ K = FOV_to_intrinsics(self.fov)
216
+ Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
217
+
218
+ data = {
219
+ 'input_images': images[:self.input_view_num], # (6, 3, H, W)
220
+ 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
221
+ 'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
222
+ 'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
223
+ 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
224
+ 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
225
+
226
+ # lrm generator input and supervision
227
+ 'target_images': images[self.input_view_num:], # (V, 3, H, W)
228
+ 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
229
+ 'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
230
+ 'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
231
+ 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
232
+ 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
233
+
234
+ 'depth_available': 1,
235
+ }
236
+ return data
237
+
238
+
239
+ class ValidationData(Dataset):
240
+ def __init__(self,
241
+ root_dir='objaverse/',
242
+ input_view_num=6,
243
+ input_image_size=256,
244
+ fov=50,
245
+ ):
246
+ self.root_dir = Path(root_dir)
247
+ self.input_view_num = input_view_num
248
+ self.input_image_size = input_image_size
249
+ self.fov = fov
250
+
251
+ self.paths = sorted(os.listdir(self.root_dir))
252
+ print('============= length of dataset %d =============' % len(self.paths))
253
+
254
+ cam_distance = 2.5
255
+ azimuths = np.array([30, 90, 150, 210, 270, 330])
256
+ elevations = np.array([30, -20, 30, -20, 30, -20])
257
+ azimuths = np.deg2rad(azimuths)
258
+ elevations = np.deg2rad(elevations)
259
+
260
+ x = cam_distance * np.cos(elevations) * np.cos(azimuths)
261
+ y = cam_distance * np.cos(elevations) * np.sin(azimuths)
262
+ z = cam_distance * np.sin(elevations)
263
+
264
+ cam_locations = np.stack([x, y, z], axis=-1)
265
+ cam_locations = torch.from_numpy(cam_locations).float()
266
+ c2ws = center_looking_at_camera_pose(cam_locations)
267
+ self.c2ws = c2ws.float()
268
+ self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
269
+
270
+ render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
271
+ render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
272
+ self.render_c2ws = render_c2ws.float()
273
+ self.render_Ks = render_Ks.float()
274
+
275
+ def __len__(self):
276
+ return len(self.paths)
277
+
278
+ def load_im(self, path, color):
279
+ """Replace background pixel with random color in rendering."""
280
+ pil_img = Image.open(path)
281
+ pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
282
+
283
+ image = np.asarray(pil_img, dtype=np.float32) / 255.
284
+ if image.shape[-1] == 4:
285
+ alpha = image[:, :, 3:]
286
+ image = image[:, :, :3] * alpha + color * (1 - alpha)
287
+ else:
288
+ alpha = np.ones_like(image[:, :, :1])
289
+
290
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
291
+ alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
292
+ return image, alpha
293
+
294
+ def __getitem__(self, index):
295
+ # load data
296
+ input_image_path = os.path.join(self.root_dir, self.paths[index])
297
+
298
+ '''background color, default: white'''
299
+ # color = np.random.uniform(0.48, 0.52)
300
+ bkg_color = [1.0, 1.0, 1.0]
301
+
302
+ image_list = []
303
+ alpha_list = []
304
+
305
+ for idx in range(self.input_view_num):
306
+ image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
307
+ image_list.append(image)
308
+ alpha_list.append(alpha)
309
+
310
+ images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
311
+ alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
312
+
313
+ data = {
314
+ 'input_images': images, # (6, 3, H, W)
315
+ 'input_alphas': alphas, # (6, 1, H, W)
316
+ 'input_c2ws': self.c2ws, # (6, 4, 4)
317
+ 'input_Ks': self.Ks, # (6, 3, 3)
318
+
319
+ 'render_c2ws': self.render_c2ws,
320
+ 'render_Ks': self.render_Ks,
321
+ }
322
+ return data
src/lib.rs DELETED
@@ -1 +0,0 @@
1
- //! Example of a Rust library.
 
 
src/main.cpp DELETED
@@ -1,8 +0,0 @@
1
- #include <cstdio>
2
-
3
- #include <rerun.hpp>
4
-
5
- int main(int argc, const char* argv[]) {
6
- printf("Hello, World!\n");
7
- return 0;
8
- }
 
 
 
 
 
 
 
 
 
src/main.rs DELETED
@@ -1,5 +0,0 @@
1
- //! Example of a Rust binary.
2
-
3
- fn main() {
4
- println!("Hello, PROJ_NAME!");
5
- }
 
 
 
 
 
 
src/model.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
11
+ from torchvision.transforms import v2
12
+ from torchvision.utils import make_grid, save_image
13
+
14
+ from src.utils.train_util import instantiate_from_config
15
+
16
+
17
+ class MVRecon(pl.LightningModule):
18
+ def __init__(
19
+ self,
20
+ lrm_generator_config,
21
+ lrm_path=None,
22
+ input_size=256,
23
+ render_size=192,
24
+ ):
25
+ super().__init__()
26
+
27
+ self.input_size = input_size
28
+ self.render_size = render_size
29
+
30
+ # init modules
31
+ self.lrm_generator = instantiate_from_config(lrm_generator_config)
32
+ if lrm_path is not None:
33
+ lrm_ckpt = torch.load(lrm_path)
34
+ self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
35
+
36
+ self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
37
+
38
+ self.validation_step_outputs = []
39
+
40
+ def on_fit_start(self):
41
+ if self.global_rank == 0:
42
+ os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
43
+ os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
44
+
45
+ def prepare_batch_data(self, batch):
46
+ lrm_generator_input = {}
47
+ render_gt = {} # for supervision
48
+
49
+ # input images
50
+ images = batch['input_images']
51
+ images = v2.functional.resize(
52
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
53
+
54
+ lrm_generator_input['images'] = images.to(self.device)
55
+
56
+ # input cameras and render cameras
57
+ input_c2ws = batch['input_c2ws'].flatten(-2)
58
+ input_Ks = batch['input_Ks'].flatten(-2)
59
+ target_c2ws = batch['target_c2ws'].flatten(-2)
60
+ target_Ks = batch['target_Ks'].flatten(-2)
61
+ render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
62
+ render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
63
+ render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
64
+
65
+ input_extrinsics = input_c2ws[:, :, :12]
66
+ input_intrinsics = torch.stack([
67
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
68
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
69
+ ], dim=-1)
70
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
71
+
72
+ # add noise to input cameras
73
+ cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
74
+
75
+ lrm_generator_input['cameras'] = cameras.to(self.device)
76
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
77
+
78
+ # target images
79
+ target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
80
+ target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
81
+ target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
82
+
83
+ # random crop
84
+ render_size = np.random.randint(self.render_size, 513)
85
+ target_images = v2.functional.resize(
86
+ target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
87
+ target_depths = v2.functional.resize(
88
+ target_depths, render_size, interpolation=0, antialias=True)
89
+ target_alphas = v2.functional.resize(
90
+ target_alphas, render_size, interpolation=0, antialias=True)
91
+
92
+ crop_params = v2.RandomCrop.get_params(
93
+ target_images, output_size=(self.render_size, self.render_size))
94
+ target_images = v2.functional.crop(target_images, *crop_params)
95
+ target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
96
+ target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
97
+
98
+ lrm_generator_input['render_size'] = render_size
99
+ lrm_generator_input['crop_params'] = crop_params
100
+
101
+ render_gt['target_images'] = target_images.to(self.device)
102
+ render_gt['target_depths'] = target_depths.to(self.device)
103
+ render_gt['target_alphas'] = target_alphas.to(self.device)
104
+
105
+ return lrm_generator_input, render_gt
106
+
107
+ def prepare_validation_batch_data(self, batch):
108
+ lrm_generator_input = {}
109
+
110
+ # input images
111
+ images = batch['input_images']
112
+ images = v2.functional.resize(
113
+ images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
114
+
115
+ lrm_generator_input['images'] = images.to(self.device)
116
+
117
+ input_c2ws = batch['input_c2ws'].flatten(-2)
118
+ input_Ks = batch['input_Ks'].flatten(-2)
119
+
120
+ input_extrinsics = input_c2ws[:, :, :12]
121
+ input_intrinsics = torch.stack([
122
+ input_Ks[:, :, 0], input_Ks[:, :, 4],
123
+ input_Ks[:, :, 2], input_Ks[:, :, 5],
124
+ ], dim=-1)
125
+ cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
126
+
127
+ lrm_generator_input['cameras'] = cameras.to(self.device)
128
+
129
+ render_c2ws = batch['render_c2ws'].flatten(-2)
130
+ render_Ks = batch['render_Ks'].flatten(-2)
131
+ render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
132
+
133
+ lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
134
+ lrm_generator_input['render_size'] = 384
135
+ lrm_generator_input['crop_params'] = None
136
+
137
+ return lrm_generator_input
138
+
139
+ def forward_lrm_generator(
140
+ self,
141
+ images,
142
+ cameras,
143
+ render_cameras,
144
+ render_size=192,
145
+ crop_params=None,
146
+ chunk_size=1,
147
+ ):
148
+ planes = torch.utils.checkpoint.checkpoint(
149
+ self.lrm_generator.forward_planes,
150
+ images,
151
+ cameras,
152
+ use_reentrant=False,
153
+ )
154
+ frames = []
155
+ for i in range(0, render_cameras.shape[1], chunk_size):
156
+ frames.append(
157
+ torch.utils.checkpoint.checkpoint(
158
+ self.lrm_generator.synthesizer,
159
+ planes,
160
+ cameras=render_cameras[:, i:i+chunk_size],
161
+ render_size=render_size,
162
+ crop_params=crop_params,
163
+ use_reentrant=False
164
+ )
165
+ )
166
+ frames = {
167
+ k: torch.cat([r[k] for r in frames], dim=1)
168
+ for k in frames[0].keys()
169
+ }
170
+ return frames
171
+
172
+ def forward(self, lrm_generator_input):
173
+ images = lrm_generator_input['images']
174
+ cameras = lrm_generator_input['cameras']
175
+ render_cameras = lrm_generator_input['render_cameras']
176
+ render_size = lrm_generator_input['render_size']
177
+ crop_params = lrm_generator_input['crop_params']
178
+
179
+ out = self.forward_lrm_generator(
180
+ images,
181
+ cameras,
182
+ render_cameras,
183
+ render_size=render_size,
184
+ crop_params=crop_params,
185
+ chunk_size=1,
186
+ )
187
+ render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
188
+ render_depths = out['images_depth']
189
+ render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
190
+
191
+ out = {
192
+ 'render_images': render_images,
193
+ 'render_depths': render_depths,
194
+ 'render_alphas': render_alphas,
195
+ }
196
+ return out
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ lrm_generator_input, render_gt = self.prepare_batch_data(batch)
200
+
201
+ render_out = self.forward(lrm_generator_input)
202
+
203
+ loss, loss_dict = self.compute_loss(render_out, render_gt)
204
+
205
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
206
+
207
+ if self.global_step % 1000 == 0 and self.global_rank == 0:
208
+ B, N, C, H, W = render_gt['target_images'].shape
209
+ N_in = lrm_generator_input['images'].shape[1]
210
+
211
+ input_images = v2.functional.resize(
212
+ lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
213
+ input_images = torch.cat(
214
+ [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
215
+
216
+ input_images = rearrange(
217
+ input_images, 'b n c h w -> b c h (n w)')
218
+ target_images = rearrange(
219
+ render_gt['target_images'], 'b n c h w -> b c h (n w)')
220
+ render_images = rearrange(
221
+ render_out['render_images'], 'b n c h w -> b c h (n w)')
222
+ target_alphas = rearrange(
223
+ repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
224
+ render_alphas = rearrange(
225
+ repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
226
+ target_depths = rearrange(
227
+ repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
228
+ render_depths = rearrange(
229
+ repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
230
+ MAX_DEPTH = torch.max(target_depths)
231
+ target_depths = target_depths / MAX_DEPTH * target_alphas
232
+ render_depths = render_depths / MAX_DEPTH
233
+
234
+ grid = torch.cat([
235
+ input_images,
236
+ target_images, render_images,
237
+ target_alphas, render_alphas,
238
+ target_depths, render_depths,
239
+ ], dim=-2)
240
+ grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
241
+
242
+ save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
243
+
244
+ return loss
245
+
246
+ def compute_loss(self, render_out, render_gt):
247
+ # NOTE: the rgb value range of OpenLRM is [0, 1]
248
+ render_images = render_out['render_images']
249
+ target_images = render_gt['target_images'].to(render_images)
250
+ render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
251
+ target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
252
+
253
+ loss_mse = F.mse_loss(render_images, target_images)
254
+ loss_lpips = 2.0 * self.lpips(render_images, target_images)
255
+
256
+ render_alphas = render_out['render_alphas']
257
+ target_alphas = render_gt['target_alphas']
258
+ loss_mask = F.mse_loss(render_alphas, target_alphas)
259
+
260
+ loss = loss_mse + loss_lpips + loss_mask
261
+
262
+ prefix = 'train'
263
+ loss_dict = {}
264
+ loss_dict.update({f'{prefix}/loss_mse': loss_mse})
265
+ loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
266
+ loss_dict.update({f'{prefix}/loss_mask': loss_mask})
267
+ loss_dict.update({f'{prefix}/loss': loss})
268
+
269
+ return loss, loss_dict
270
+
271
+ @torch.no_grad()
272
+ def validation_step(self, batch, batch_idx):
273
+ lrm_generator_input = self.prepare_validation_batch_data(batch)
274
+
275
+ render_out = self.forward(lrm_generator_input)
276
+ render_images = render_out['render_images']
277
+ render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
278
+
279
+ self.validation_step_outputs.append(render_images)
280
+
281
+ def on_validation_epoch_end(self):
282
+ images = torch.cat(self.validation_step_outputs, dim=-1)
283
+
284
+ all_images = self.all_gather(images)
285
+ all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
286
+
287
+ if self.global_rank == 0:
288
+ image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
289
+
290
+ grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
291
+ save_image(grid, image_path)
292
+ print(f"Saved image to {image_path}")
293
+
294
+ self.validation_step_outputs.clear()
295
+
296
+ def configure_optimizers(self):
297
+ lr = self.learning_rate
298
+
299
+ params = []
300
+
301
+ lrm_params_fast, lrm_params_slow = [], []
302
+ for n, p in self.lrm_generator.named_parameters():
303
+ if 'adaLN_modulation' in n or 'camera_embedder' in n:
304
+ lrm_params_fast.append(p)
305
+ else:
306
+ lrm_params_slow.append(p)
307
+ params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
308
+ params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
309
+
310
+ optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
311
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
312
+
313
+ return {'optimizer': optimizer, 'lr_scheduler': scheduler}