Spaces:
Sleeping
Sleeping
02alexander
commited on
Commit
·
71d5bf5
1
Parent(s):
344c16f
copy code to this repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- CMakeLists.txt +0 -18
- Cargo.lock +0 -7
- Cargo.toml +0 -198
- README.md +2 -39
- app.py +308 -0
- configs/instant-mesh-base.yaml +22 -0
- configs/instant-mesh-large.yaml +22 -0
- configs/instant-nerf-base.yaml +21 -0
- configs/instant-nerf-large.yaml +21 -0
- examples/bird.jpg +0 -0
- examples/bubble_mart_blue.png +0 -0
- examples/cake.jpg +0 -0
- examples/cartoon_dinosaur.png +0 -0
- examples/chair_armed.png +0 -0
- examples/chair_comfort.jpg +0 -0
- examples/chair_wood.jpg +0 -0
- examples/chest.jpg +0 -0
- examples/cute_horse.jpg +0 -0
- examples/cute_tiger.jpg +0 -0
- examples/earphone.jpg +0 -0
- examples/fox.jpg +0 -0
- examples/fruit.jpg +0 -0
- examples/fruit_elephant.jpg +0 -0
- examples/genshin_building.png +0 -0
- examples/genshin_teapot.png +0 -0
- examples/hatsune_miku.png +0 -0
- examples/house2.jpg +0 -0
- examples/mushroom_teapot.jpg +0 -0
- examples/pikachu.png +0 -0
- examples/plant.jpg +0 -0
- examples/robot.jpg +0 -0
- examples/sea_turtle.png +0 -0
- examples/skating_shoe.jpg +0 -0
- examples/sorting_board.png +0 -0
- examples/sword.png +0 -0
- examples/toy_car.jpg +0 -0
- examples/watermelon.png +0 -0
- examples/whitedog.png +0 -0
- examples/x_teapot.jpg +0 -0
- examples/x_toyduck.jpg +0 -0
- main.py +0 -11
- requirements.txt +27 -1
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/objaverse.py +322 -0
- src/lib.rs +0 -1
- src/main.cpp +0 -8
- src/main.rs +0 -5
- src/model.py +313 -0
.gitignore
CHANGED
@@ -20,3 +20,5 @@ __pycache__
|
|
20 |
.mypy_cache
|
21 |
.ruff_cache
|
22 |
venv
|
|
|
|
|
|
20 |
.mypy_cache
|
21 |
.ruff_cache
|
22 |
venv
|
23 |
+
|
24 |
+
shell.nix
|
CMakeLists.txt
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
cmake_minimum_required(VERSION 3.16...3.27)
|
2 |
-
|
3 |
-
project(PROJ_NAME LANGUAGES CXX)
|
4 |
-
|
5 |
-
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
6 |
-
|
7 |
-
if(NOT DEFINED CMAKE_CXX_STANDARD)
|
8 |
-
set(CMAKE_CXX_STANDARD 17)
|
9 |
-
endif()
|
10 |
-
|
11 |
-
# Rerun:
|
12 |
-
include(FetchContent)
|
13 |
-
FetchContent_Declare(rerun_sdk URL https://github.com/rerun-io/rerun/releases/download/0.15.1/rerun_cpp_sdk.zip)
|
14 |
-
FetchContent_MakeAvailable(rerun_sdk)
|
15 |
-
|
16 |
-
add_executable(PROJ_NAME src/main.cpp)
|
17 |
-
target_link_libraries(PROJ_NAME rerun_sdk)
|
18 |
-
target_include_directories(PROJ_NAME PRIVATE src)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Cargo.lock
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
# This file is automatically @generated by Cargo.
|
2 |
-
# It is not intended for manual editing.
|
3 |
-
version = 3
|
4 |
-
|
5 |
-
[[package]]
|
6 |
-
name = "new_project_name"
|
7 |
-
version = "0.1.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Cargo.toml
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
[package]
|
2 |
-
authors = ["rerun.io <opensource@rerun.io>"]
|
3 |
-
categories = [] # TODO: fill in if you plan on publishing the crate
|
4 |
-
description = "" # TODO: fill in if you plan on publishing the crate
|
5 |
-
edition = "2021"
|
6 |
-
homepage = "https://github.com/rerun-io/new_repo_name"
|
7 |
-
include = ["LICENSE-APACHE", "LICENSE-MIT", "**/*.rs", "Cargo.toml"]
|
8 |
-
keywords = [] # TODO: fill in if you plan on publishing the crate
|
9 |
-
license = "MIT OR Apache-2.0"
|
10 |
-
name = "new_project_name"
|
11 |
-
publish = false # TODO: set to `true` if you plan on publishing the crate
|
12 |
-
readme = "README.md"
|
13 |
-
repository = "https://github.com/rerun-io/new_repo_name"
|
14 |
-
rust-version = "1.76"
|
15 |
-
version = "0.1.0"
|
16 |
-
|
17 |
-
[package.metadata.docs.rs]
|
18 |
-
all-features = true
|
19 |
-
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
|
20 |
-
|
21 |
-
|
22 |
-
[features]
|
23 |
-
default = []
|
24 |
-
|
25 |
-
|
26 |
-
[dependencies]
|
27 |
-
|
28 |
-
|
29 |
-
[dev-dependencies]
|
30 |
-
|
31 |
-
|
32 |
-
[patch.crates-io]
|
33 |
-
|
34 |
-
|
35 |
-
[lints]
|
36 |
-
workspace = true
|
37 |
-
|
38 |
-
|
39 |
-
[workspace.lints.rust]
|
40 |
-
unsafe_code = "deny"
|
41 |
-
|
42 |
-
elided_lifetimes_in_paths = "warn"
|
43 |
-
future_incompatible = "warn"
|
44 |
-
nonstandard_style = "warn"
|
45 |
-
rust_2018_idioms = "warn"
|
46 |
-
rust_2021_prelude_collisions = "warn"
|
47 |
-
semicolon_in_expressions_from_macros = "warn"
|
48 |
-
trivial_numeric_casts = "warn"
|
49 |
-
unsafe_op_in_unsafe_fn = "warn" # `unsafe_op_in_unsafe_fn` may become the default in future Rust versions: https://github.com/rust-lang/rust/issues/71668
|
50 |
-
unused_extern_crates = "warn"
|
51 |
-
unused_import_braces = "warn"
|
52 |
-
unused_lifetimes = "warn"
|
53 |
-
|
54 |
-
trivial_casts = "allow"
|
55 |
-
unused_qualifications = "allow"
|
56 |
-
|
57 |
-
[workspace.lints.rustdoc]
|
58 |
-
all = "warn"
|
59 |
-
missing_crate_level_docs = "warn"
|
60 |
-
|
61 |
-
# See also clippy.toml
|
62 |
-
[workspace.lints.clippy]
|
63 |
-
as_ptr_cast_mut = "warn"
|
64 |
-
await_holding_lock = "warn"
|
65 |
-
bool_to_int_with_if = "warn"
|
66 |
-
char_lit_as_u8 = "warn"
|
67 |
-
checked_conversions = "warn"
|
68 |
-
clear_with_drain = "warn"
|
69 |
-
cloned_instead_of_copied = "warn"
|
70 |
-
dbg_macro = "warn"
|
71 |
-
debug_assert_with_mut_call = "warn"
|
72 |
-
derive_partial_eq_without_eq = "warn"
|
73 |
-
disallowed_macros = "warn" # See clippy.toml
|
74 |
-
disallowed_methods = "warn" # See clippy.toml
|
75 |
-
disallowed_names = "warn" # See clippy.toml
|
76 |
-
disallowed_script_idents = "warn" # See clippy.toml
|
77 |
-
disallowed_types = "warn" # See clippy.toml
|
78 |
-
doc_link_with_quotes = "warn"
|
79 |
-
doc_markdown = "warn"
|
80 |
-
empty_enum = "warn"
|
81 |
-
enum_glob_use = "warn"
|
82 |
-
equatable_if_let = "warn"
|
83 |
-
exit = "warn"
|
84 |
-
expl_impl_clone_on_copy = "warn"
|
85 |
-
explicit_deref_methods = "warn"
|
86 |
-
explicit_into_iter_loop = "warn"
|
87 |
-
explicit_iter_loop = "warn"
|
88 |
-
fallible_impl_from = "warn"
|
89 |
-
filter_map_next = "warn"
|
90 |
-
flat_map_option = "warn"
|
91 |
-
float_cmp_const = "warn"
|
92 |
-
fn_params_excessive_bools = "warn"
|
93 |
-
fn_to_numeric_cast_any = "warn"
|
94 |
-
from_iter_instead_of_collect = "warn"
|
95 |
-
get_unwrap = "warn"
|
96 |
-
if_let_mutex = "warn"
|
97 |
-
implicit_clone = "warn"
|
98 |
-
imprecise_flops = "warn"
|
99 |
-
index_refutable_slice = "warn"
|
100 |
-
inefficient_to_string = "warn"
|
101 |
-
infinite_loop = "warn"
|
102 |
-
into_iter_without_iter = "warn"
|
103 |
-
invalid_upcast_comparisons = "warn"
|
104 |
-
iter_not_returning_iterator = "warn"
|
105 |
-
iter_on_empty_collections = "warn"
|
106 |
-
iter_on_single_items = "warn"
|
107 |
-
iter_over_hash_type = "warn"
|
108 |
-
iter_without_into_iter = "warn"
|
109 |
-
large_digit_groups = "warn"
|
110 |
-
large_include_file = "warn"
|
111 |
-
large_stack_arrays = "warn"
|
112 |
-
large_stack_frames = "warn"
|
113 |
-
large_types_passed_by_value = "warn"
|
114 |
-
let_underscore_untyped = "warn"
|
115 |
-
let_unit_value = "warn"
|
116 |
-
linkedlist = "warn"
|
117 |
-
lossy_float_literal = "warn"
|
118 |
-
macro_use_imports = "warn"
|
119 |
-
manual_assert = "warn"
|
120 |
-
manual_clamp = "warn"
|
121 |
-
manual_instant_elapsed = "warn"
|
122 |
-
manual_let_else = "warn"
|
123 |
-
manual_ok_or = "warn"
|
124 |
-
manual_string_new = "warn"
|
125 |
-
map_err_ignore = "warn"
|
126 |
-
map_flatten = "warn"
|
127 |
-
map_unwrap_or = "warn"
|
128 |
-
match_on_vec_items = "warn"
|
129 |
-
match_same_arms = "warn"
|
130 |
-
match_wild_err_arm = "warn"
|
131 |
-
match_wildcard_for_single_variants = "warn"
|
132 |
-
mem_forget = "warn"
|
133 |
-
mismatched_target_os = "warn"
|
134 |
-
mismatching_type_param_order = "warn"
|
135 |
-
missing_assert_message = "warn"
|
136 |
-
missing_enforced_import_renames = "warn"
|
137 |
-
missing_errors_doc = "warn"
|
138 |
-
missing_safety_doc = "warn"
|
139 |
-
mut_mut = "warn"
|
140 |
-
mutex_integer = "warn"
|
141 |
-
needless_borrow = "warn"
|
142 |
-
needless_continue = "warn"
|
143 |
-
needless_for_each = "warn"
|
144 |
-
needless_pass_by_ref_mut = "warn"
|
145 |
-
needless_pass_by_value = "warn"
|
146 |
-
negative_feature_names = "warn"
|
147 |
-
nonstandard_macro_braces = "warn"
|
148 |
-
option_option = "warn"
|
149 |
-
path_buf_push_overwrite = "warn"
|
150 |
-
ptr_as_ptr = "warn"
|
151 |
-
ptr_cast_constness = "warn"
|
152 |
-
pub_without_shorthand = "warn"
|
153 |
-
rc_mutex = "warn"
|
154 |
-
readonly_write_lock = "warn"
|
155 |
-
redundant_type_annotations = "warn"
|
156 |
-
ref_option_ref = "warn"
|
157 |
-
rest_pat_in_fully_bound_structs = "warn"
|
158 |
-
same_functions_in_if_condition = "warn"
|
159 |
-
semicolon_if_nothing_returned = "warn"
|
160 |
-
should_panic_without_expect = "warn"
|
161 |
-
significant_drop_tightening = "warn"
|
162 |
-
single_match_else = "warn"
|
163 |
-
str_to_string = "warn"
|
164 |
-
string_add = "warn"
|
165 |
-
string_add_assign = "warn"
|
166 |
-
string_lit_as_bytes = "warn"
|
167 |
-
string_lit_chars_any = "warn"
|
168 |
-
string_to_string = "warn"
|
169 |
-
suspicious_command_arg_space = "warn"
|
170 |
-
suspicious_xor_used_as_pow = "warn"
|
171 |
-
todo = "warn"
|
172 |
-
too_many_lines = "warn"
|
173 |
-
trailing_empty_array = "warn"
|
174 |
-
trait_duplication_in_bounds = "warn"
|
175 |
-
tuple_array_conversions = "warn"
|
176 |
-
unchecked_duration_subtraction = "warn"
|
177 |
-
undocumented_unsafe_blocks = "warn"
|
178 |
-
unimplemented = "warn"
|
179 |
-
uninhabited_references = "warn"
|
180 |
-
uninlined_format_args = "warn"
|
181 |
-
unnecessary_box_returns = "warn"
|
182 |
-
unnecessary_safety_doc = "warn"
|
183 |
-
unnecessary_struct_initialization = "warn"
|
184 |
-
unnecessary_wraps = "warn"
|
185 |
-
unnested_or_patterns = "warn"
|
186 |
-
unused_peekable = "warn"
|
187 |
-
unused_rounding = "warn"
|
188 |
-
unused_self = "warn"
|
189 |
-
unwrap_used = "warn"
|
190 |
-
use_self = "warn"
|
191 |
-
useless_transmute = "warn"
|
192 |
-
verbose_file_reads = "warn"
|
193 |
-
wildcard_dependencies = "warn"
|
194 |
-
wildcard_imports = "warn"
|
195 |
-
zero_sized_map_values = "warn"
|
196 |
-
|
197 |
-
manual_range_contains = "allow" # this one is just worse imho
|
198 |
-
ref_patterns = "allow" # It's nice to avoid ref pattern, but there are some situations that are hard (impossible?) to express without.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,40 +1,3 @@
|
|
1 |
-
|
2 |
-
Template for our private and public repos, containing CI, CoC, etc
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
This template should be the default for any repository of any kind, including:
|
7 |
-
* Rust projects
|
8 |
-
* C++ projects
|
9 |
-
* Python projects
|
10 |
-
* Other stuff
|
11 |
-
|
12 |
-
This template includes
|
13 |
-
* License files
|
14 |
-
* Code of Conduct
|
15 |
-
* Helpers for checking and linting Rust code
|
16 |
-
- `cargo-clippy`
|
17 |
-
- `cargo-deny`
|
18 |
-
- `rust-toolchain`
|
19 |
-
- …
|
20 |
-
* CI for:
|
21 |
-
- Spell checking
|
22 |
-
- Link checking
|
23 |
-
- C++ checks
|
24 |
-
- Python checks
|
25 |
-
- Rust checks
|
26 |
-
|
27 |
-
|
28 |
-
## How to use
|
29 |
-
Start by clicking "Use this template" at https://github.com/rerun-io/rerun_template/ or follow [these instructions](https://docs.github.com/en/free-pro-team@latest/github/creating-cloning-and-archiving-repositories/creating-a-repository-from-a-template).
|
30 |
-
|
31 |
-
Then follow these steps:
|
32 |
-
* Run `scripts/template_update.py init --languages cpp,rust,python` to delete files you don't need (give the languages you need support for)
|
33 |
-
* Search and replace all instances of `new_repo_name` with the name of the repository.
|
34 |
-
* Search and replace all instances of `new_project_name` with the name of the project (crate/binary name).
|
35 |
-
* Search for `TODO` and fill in all those places
|
36 |
-
* Replace this `README.md` with something better
|
37 |
-
* Commit!
|
38 |
-
|
39 |
-
In the future you can always update this repository with the latest changes from the template by running:
|
40 |
-
* `scripts/template_update.py update --languages cpp,rust,python`
|
|
|
1 |
+
## Fork of the [InstantMesh space]() but with [Rerun](https://www.rerun.io) for visualization
|
|
|
2 |
|
3 |
+
The resulting Huggingface space can be found [here.](https://huggingface.co/spaces/rerun/InstantMesh)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import threading
|
6 |
+
from queue import SimpleQueue
|
7 |
+
from typing import Any
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
import rembg
|
12 |
+
import rerun as rr
|
13 |
+
import rerun.blueprint as rrb
|
14 |
+
import spaces
|
15 |
+
import torch
|
16 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
17 |
+
from einops import rearrange
|
18 |
+
from gradio_rerun import Rerun
|
19 |
+
from huggingface_hub import hf_hub_download
|
20 |
+
from omegaconf import OmegaConf
|
21 |
+
from PIL import Image
|
22 |
+
from pytorch_lightning import seed_everything
|
23 |
+
from torchvision.transforms import v2
|
24 |
+
|
25 |
+
from src.models.lrm_mesh import InstantMesh
|
26 |
+
from src.utils.camera_util import (
|
27 |
+
FOV_to_intrinsics,
|
28 |
+
get_circular_camera_poses,
|
29 |
+
get_zero123plus_input_cameras,
|
30 |
+
)
|
31 |
+
from src.utils.infer_util import remove_background, resize_foreground
|
32 |
+
from src.utils.train_util import instantiate_from_config
|
33 |
+
|
34 |
+
|
35 |
+
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
|
36 |
+
"""Get the rendering camera parameters."""
|
37 |
+
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
|
38 |
+
if is_flexicubes:
|
39 |
+
cameras = torch.linalg.inv(c2ws)
|
40 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
41 |
+
else:
|
42 |
+
extrinsics = c2ws.flatten(-2)
|
43 |
+
intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
|
44 |
+
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
|
45 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
46 |
+
return cameras
|
47 |
+
|
48 |
+
|
49 |
+
###############################################################################
|
50 |
+
# Configuration.
|
51 |
+
###############################################################################
|
52 |
+
|
53 |
+
|
54 |
+
def find_cuda():
|
55 |
+
# Check if CUDA_HOME or CUDA_PATH environment variables are set
|
56 |
+
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
57 |
+
|
58 |
+
if cuda_home and os.path.exists(cuda_home):
|
59 |
+
return cuda_home
|
60 |
+
|
61 |
+
# Search for the nvcc executable in the system's PATH
|
62 |
+
nvcc_path = shutil.which("nvcc")
|
63 |
+
|
64 |
+
if nvcc_path:
|
65 |
+
# Remove the 'bin/nvcc' part to get the CUDA installation path
|
66 |
+
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
|
67 |
+
return cuda_path
|
68 |
+
|
69 |
+
return None
|
70 |
+
|
71 |
+
|
72 |
+
cuda_path = find_cuda()
|
73 |
+
|
74 |
+
if cuda_path:
|
75 |
+
print(f"CUDA installation found at: {cuda_path}")
|
76 |
+
else:
|
77 |
+
print("CUDA installation not found")
|
78 |
+
|
79 |
+
config_path = "configs/instant-mesh-large.yaml"
|
80 |
+
config = OmegaConf.load(config_path)
|
81 |
+
config_name = os.path.basename(config_path).replace(".yaml", "")
|
82 |
+
model_config = config.model_config
|
83 |
+
infer_config = config.infer_config
|
84 |
+
|
85 |
+
IS_FLEXICUBES = True if config_name.startswith("instant-mesh") else False
|
86 |
+
|
87 |
+
device = torch.device("cuda")
|
88 |
+
|
89 |
+
# load diffusion model
|
90 |
+
print("Loading diffusion model ...")
|
91 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
92 |
+
"sudo-ai/zero123plus-v1.2",
|
93 |
+
custom_pipeline="zero123plus",
|
94 |
+
torch_dtype=torch.float16,
|
95 |
+
)
|
96 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
|
97 |
+
|
98 |
+
# load custom white-background UNet
|
99 |
+
unet_ckpt_path = hf_hub_download(
|
100 |
+
repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model"
|
101 |
+
)
|
102 |
+
state_dict = torch.load(unet_ckpt_path, map_location="cpu")
|
103 |
+
pipeline.unet.load_state_dict(state_dict, strict=True)
|
104 |
+
|
105 |
+
pipeline = pipeline.to(device)
|
106 |
+
print(f"type(pipeline)={type(pipeline)}")
|
107 |
+
|
108 |
+
# load reconstruction model
|
109 |
+
print("Loading reconstruction model ...")
|
110 |
+
model_ckpt_path = hf_hub_download(
|
111 |
+
repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model"
|
112 |
+
)
|
113 |
+
model: InstantMesh = instantiate_from_config(model_config)
|
114 |
+
state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
|
115 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith("lrm_generator.") and "source_camera" not in k}
|
116 |
+
model.load_state_dict(state_dict, strict=True)
|
117 |
+
|
118 |
+
model = model.to(device)
|
119 |
+
|
120 |
+
print("Loading Finished!")
|
121 |
+
|
122 |
+
|
123 |
+
def check_input_image(input_image):
|
124 |
+
if input_image is None:
|
125 |
+
raise gr.Error("No image uploaded!")
|
126 |
+
|
127 |
+
|
128 |
+
def preprocess(input_image, do_remove_background):
|
129 |
+
rembg_session = rembg.new_session() if do_remove_background else None
|
130 |
+
|
131 |
+
if do_remove_background:
|
132 |
+
input_image = remove_background(input_image, rembg_session)
|
133 |
+
input_image = resize_foreground(input_image, 0.85)
|
134 |
+
|
135 |
+
return input_image
|
136 |
+
|
137 |
+
|
138 |
+
def pipeline_callback(
|
139 |
+
log_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]
|
140 |
+
) -> dict[str, Any]:
|
141 |
+
latents = callback_kwargs["latents"]
|
142 |
+
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] # type: ignore[attr-defined]
|
143 |
+
image = pipe.image_processor.postprocess(image, output_type="np").squeeze() # type: ignore[attr-defined]
|
144 |
+
|
145 |
+
log_queue.put(("mvs", rr.Image(image)))
|
146 |
+
log_queue.put(("latents", rr.Tensor(latents.squeeze())))
|
147 |
+
|
148 |
+
return callback_kwargs
|
149 |
+
|
150 |
+
|
151 |
+
def generate_mvs(log_queue, input_image, sample_steps, sample_seed):
|
152 |
+
seed_everything(sample_seed)
|
153 |
+
|
154 |
+
return pipeline(
|
155 |
+
input_image,
|
156 |
+
num_inference_steps=sample_steps,
|
157 |
+
callback_on_step_end=lambda *args, **kwargs: pipeline_callback(log_queue, *args, **kwargs),
|
158 |
+
).images[0]
|
159 |
+
|
160 |
+
|
161 |
+
def make3d(log_queue, images: Image.Image):
|
162 |
+
global model
|
163 |
+
if IS_FLEXICUBES:
|
164 |
+
model.init_flexicubes_geometry(device, use_renderer=False)
|
165 |
+
model = model.eval()
|
166 |
+
|
167 |
+
images = np.asarray(images, dtype=np.float32) / 255.0
|
168 |
+
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
|
169 |
+
images = rearrange(images, "c (n h) (m w) -> (n m) c h w", n=3, m=2) # (6, 3, 320, 320)
|
170 |
+
|
171 |
+
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
|
172 |
+
|
173 |
+
images = images.unsqueeze(0).to(device)
|
174 |
+
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
|
175 |
+
|
176 |
+
with torch.no_grad():
|
177 |
+
# get triplane
|
178 |
+
planes = model.forward_planes(images, input_cameras)
|
179 |
+
|
180 |
+
# get mesh
|
181 |
+
mesh_out = model.extract_mesh(
|
182 |
+
planes,
|
183 |
+
use_texture_map=False,
|
184 |
+
**infer_config,
|
185 |
+
)
|
186 |
+
|
187 |
+
vertices, faces, vertex_colors = mesh_out
|
188 |
+
|
189 |
+
log_queue.put((
|
190 |
+
"mesh",
|
191 |
+
rr.Mesh3D(vertex_positions=vertices, vertex_colors=vertex_colors, triangle_indices=faces),
|
192 |
+
))
|
193 |
+
|
194 |
+
return mesh_out
|
195 |
+
|
196 |
+
|
197 |
+
def generate_blueprint() -> rrb.Blueprint:
|
198 |
+
return rrb.Blueprint(
|
199 |
+
rrb.Horizontal(
|
200 |
+
rrb.Spatial3DView(origin="mesh"),
|
201 |
+
rrb.Grid(
|
202 |
+
rrb.Spatial2DView(origin="z123image"),
|
203 |
+
rrb.Spatial2DView(origin="preprocessed_image"),
|
204 |
+
rrb.Spatial2DView(origin="mvs"),
|
205 |
+
rrb.TensorView(
|
206 |
+
origin="latents",
|
207 |
+
),
|
208 |
+
),
|
209 |
+
column_shares=[1, 1],
|
210 |
+
),
|
211 |
+
collapse_panels=True,
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
def compute(log_queue, input_image, do_remove_background, sample_steps, sample_seed):
|
216 |
+
preprocessed_image = preprocess(input_image, do_remove_background)
|
217 |
+
log_queue.put(("preprocessed_image", rr.Image(preprocessed_image)))
|
218 |
+
|
219 |
+
z123_image = generate_mvs(log_queue, preprocessed_image, sample_steps, sample_seed)
|
220 |
+
log_queue.put(("z123image", rr.Image(z123_image)))
|
221 |
+
|
222 |
+
_mesh_out = make3d(log_queue, z123_image)
|
223 |
+
|
224 |
+
log_queue.put("done")
|
225 |
+
|
226 |
+
|
227 |
+
@spaces.GPU
|
228 |
+
@rr.thread_local_stream("InstantMesh")
|
229 |
+
def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
|
230 |
+
log_queue = SimpleQueue()
|
231 |
+
|
232 |
+
stream = rr.binary_stream()
|
233 |
+
|
234 |
+
blueprint = generate_blueprint()
|
235 |
+
rr.send_blueprint(blueprint)
|
236 |
+
yield stream.read()
|
237 |
+
|
238 |
+
handle = threading.Thread(
|
239 |
+
target=compute, args=[log_queue, input_image, do_remove_background, sample_steps, sample_seed]
|
240 |
+
)
|
241 |
+
handle.start()
|
242 |
+
while True:
|
243 |
+
msg = log_queue.get()
|
244 |
+
if msg == "done":
|
245 |
+
break
|
246 |
+
else:
|
247 |
+
entity_path, entity = msg
|
248 |
+
rr.log(entity_path, entity)
|
249 |
+
yield stream.read()
|
250 |
+
handle.join()
|
251 |
+
|
252 |
+
|
253 |
+
_HEADER_ = """
|
254 |
+
<h2><b>Duplicate of the <a href='https://huggingface.co/spaces/TencentARC/InstantMesh'>InstantMesh space</a> that uses <a href='https://rerun.io/'>Rerun</a> for visualization.</b></h2>
|
255 |
+
<h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
|
256 |
+
|
257 |
+
**InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
|
258 |
+
|
259 |
+
Technical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
|
260 |
+
Source code: <a href='https://github.com/rerun-io/hf-example-instant-mesh'>Github</a>.
|
261 |
+
"""
|
262 |
+
|
263 |
+
with gr.Blocks() as demo:
|
264 |
+
gr.Markdown(_HEADER_)
|
265 |
+
with gr.Row(variant="panel"):
|
266 |
+
with gr.Column(scale=1):
|
267 |
+
with gr.Row():
|
268 |
+
input_image = gr.Image(
|
269 |
+
label="Input Image",
|
270 |
+
image_mode="RGBA",
|
271 |
+
sources="upload",
|
272 |
+
# width=256,
|
273 |
+
# height=256,
|
274 |
+
type="pil",
|
275 |
+
elem_id="content_image",
|
276 |
+
)
|
277 |
+
with gr.Row():
|
278 |
+
with gr.Group():
|
279 |
+
do_remove_background = gr.Checkbox(label="Remove Background", value=True)
|
280 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
281 |
+
|
282 |
+
sample_steps = gr.Slider(label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
|
283 |
+
|
284 |
+
with gr.Row():
|
285 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
286 |
+
|
287 |
+
with gr.Row(variant="panel"):
|
288 |
+
gr.Examples(
|
289 |
+
examples=[os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))],
|
290 |
+
inputs=[input_image],
|
291 |
+
label="Examples",
|
292 |
+
cache_examples=False,
|
293 |
+
examples_per_page=16,
|
294 |
+
)
|
295 |
+
|
296 |
+
with gr.Column(scale=2):
|
297 |
+
viewer = Rerun(streaming=True, height=800)
|
298 |
+
|
299 |
+
with gr.Row():
|
300 |
+
gr.Markdown("""Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).""")
|
301 |
+
|
302 |
+
mv_images = gr.State()
|
303 |
+
|
304 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
305 |
+
fn=log_to_rr, inputs=[input_image, do_remove_background, sample_steps, sample_seed], outputs=[viewer]
|
306 |
+
)
|
307 |
+
|
308 |
+
demo.launch()
|
configs/instant-mesh-base.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_base.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-mesh-large.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_large.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-nerf-base.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_base.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
configs/instant-nerf-large.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_large.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
examples/bird.jpg
ADDED
examples/bubble_mart_blue.png
ADDED
examples/cake.jpg
ADDED
examples/cartoon_dinosaur.png
ADDED
examples/chair_armed.png
ADDED
examples/chair_comfort.jpg
ADDED
examples/chair_wood.jpg
ADDED
examples/chest.jpg
ADDED
examples/cute_horse.jpg
ADDED
examples/cute_tiger.jpg
ADDED
examples/earphone.jpg
ADDED
examples/fox.jpg
ADDED
examples/fruit.jpg
ADDED
examples/fruit_elephant.jpg
ADDED
examples/genshin_building.png
ADDED
examples/genshin_teapot.png
ADDED
examples/hatsune_miku.png
ADDED
examples/house2.jpg
ADDED
examples/mushroom_teapot.jpg
ADDED
examples/pikachu.png
ADDED
examples/plant.jpg
ADDED
examples/robot.jpg
ADDED
examples/sea_turtle.png
ADDED
examples/skating_shoe.jpg
ADDED
examples/sorting_board.png
ADDED
examples/sword.png
ADDED
examples/toy_car.jpg
ADDED
examples/watermelon.png
ADDED
examples/whitedog.png
ADDED
examples/x_teapot.jpg
ADDED
examples/x_toyduck.jpg
ADDED
main.py
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
|
3 |
-
from __future__ import annotations
|
4 |
-
|
5 |
-
|
6 |
-
def main() -> None:
|
7 |
-
pass
|
8 |
-
|
9 |
-
|
10 |
-
if __name__ == "__main__":
|
11 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1 +1,27 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
spaces
|
2 |
+
torch==2.1.0
|
3 |
+
torchvision==0.16.0
|
4 |
+
torchaudio==2.1.0
|
5 |
+
pytorch-lightning==2.1.2
|
6 |
+
einops
|
7 |
+
omegaconf
|
8 |
+
deepspeed
|
9 |
+
torchmetrics
|
10 |
+
webdataset
|
11 |
+
accelerate
|
12 |
+
tensorboard
|
13 |
+
PyMCubes
|
14 |
+
trimesh
|
15 |
+
rembg
|
16 |
+
transformers
|
17 |
+
diffusers==0.28.2
|
18 |
+
bitsandbytes
|
19 |
+
imageio[ffmpeg]
|
20 |
+
xatlas
|
21 |
+
plyfile
|
22 |
+
xformers==0.0.22.post7
|
23 |
+
git+https://github.com/NVlabs/nvdiffrast/
|
24 |
+
huggingface-hub
|
25 |
+
gradio_client >= 0.12
|
26 |
+
rerun-sdk>=0.16.0,<0.17.0
|
27 |
+
gradio_rerun
|
src/__init__.py
ADDED
File without changes
|
src/data/__init__.py
ADDED
File without changes
|
src/data/objaverse.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import webdataset as wds
|
14 |
+
from PIL import Image
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
|
18 |
+
from src.utils.camera_util import (
|
19 |
+
FOV_to_intrinsics,
|
20 |
+
center_looking_at_camera_pose,
|
21 |
+
get_surrounding_views,
|
22 |
+
)
|
23 |
+
from src.utils.train_util import instantiate_from_config
|
24 |
+
|
25 |
+
|
26 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
batch_size=8,
|
30 |
+
num_workers=4,
|
31 |
+
train=None,
|
32 |
+
validation=None,
|
33 |
+
test=None,
|
34 |
+
**kwargs,
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.batch_size = batch_size
|
39 |
+
self.num_workers = num_workers
|
40 |
+
|
41 |
+
self.dataset_configs = dict()
|
42 |
+
if train is not None:
|
43 |
+
self.dataset_configs['train'] = train
|
44 |
+
if validation is not None:
|
45 |
+
self.dataset_configs['validation'] = validation
|
46 |
+
if test is not None:
|
47 |
+
self.dataset_configs['test'] = test
|
48 |
+
|
49 |
+
def setup(self, stage):
|
50 |
+
|
51 |
+
if stage in ['fit']:
|
52 |
+
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
55 |
+
|
56 |
+
def train_dataloader(self):
|
57 |
+
|
58 |
+
sampler = DistributedSampler(self.datasets['train'])
|
59 |
+
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
60 |
+
|
61 |
+
def val_dataloader(self):
|
62 |
+
|
63 |
+
sampler = DistributedSampler(self.datasets['validation'])
|
64 |
+
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
65 |
+
|
66 |
+
def test_dataloader(self):
|
67 |
+
|
68 |
+
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
69 |
+
|
70 |
+
|
71 |
+
class ObjaverseData(Dataset):
|
72 |
+
def __init__(self,
|
73 |
+
root_dir='objaverse/',
|
74 |
+
meta_fname='valid_paths.json',
|
75 |
+
input_image_dir='rendering_random_32views',
|
76 |
+
target_image_dir='rendering_random_32views',
|
77 |
+
input_view_num=6,
|
78 |
+
target_view_num=2,
|
79 |
+
total_view_n=32,
|
80 |
+
fov=50,
|
81 |
+
camera_rotation=True,
|
82 |
+
validation=False,
|
83 |
+
):
|
84 |
+
self.root_dir = Path(root_dir)
|
85 |
+
self.input_image_dir = input_image_dir
|
86 |
+
self.target_image_dir = target_image_dir
|
87 |
+
|
88 |
+
self.input_view_num = input_view_num
|
89 |
+
self.target_view_num = target_view_num
|
90 |
+
self.total_view_n = total_view_n
|
91 |
+
self.fov = fov
|
92 |
+
self.camera_rotation = camera_rotation
|
93 |
+
|
94 |
+
with open(os.path.join(root_dir, meta_fname)) as f:
|
95 |
+
filtered_dict = json.load(f)
|
96 |
+
paths = filtered_dict['good_objs']
|
97 |
+
self.paths = paths
|
98 |
+
|
99 |
+
self.depth_scale = 4.0
|
100 |
+
|
101 |
+
len(self.paths)
|
102 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return len(self.paths)
|
106 |
+
|
107 |
+
def load_im(self, path, color):
|
108 |
+
"""Replace background pixel with random color in rendering."""
|
109 |
+
pil_img = Image.open(path)
|
110 |
+
|
111 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
112 |
+
alpha = image[:, :, 3:]
|
113 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
114 |
+
|
115 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
116 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
117 |
+
return image, alpha
|
118 |
+
|
119 |
+
def __getitem__(self, index):
|
120 |
+
# load data
|
121 |
+
while True:
|
122 |
+
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
|
123 |
+
target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
|
124 |
+
|
125 |
+
indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
|
126 |
+
input_indices = indices[:self.input_view_num]
|
127 |
+
target_indices = indices[self.input_view_num:]
|
128 |
+
|
129 |
+
'''background color, default: white'''
|
130 |
+
bg_white = [1., 1., 1.]
|
131 |
+
bg_black = [0., 0., 0.]
|
132 |
+
|
133 |
+
image_list = []
|
134 |
+
alpha_list = []
|
135 |
+
depth_list = []
|
136 |
+
normal_list = []
|
137 |
+
pose_list = []
|
138 |
+
|
139 |
+
try:
|
140 |
+
input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
|
141 |
+
for idx in input_indices:
|
142 |
+
image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
|
143 |
+
normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
|
144 |
+
depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
145 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
146 |
+
pose = input_cameras[idx]
|
147 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
148 |
+
|
149 |
+
image_list.append(image)
|
150 |
+
alpha_list.append(alpha)
|
151 |
+
depth_list.append(depth)
|
152 |
+
normal_list.append(normal)
|
153 |
+
pose_list.append(pose)
|
154 |
+
|
155 |
+
target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
|
156 |
+
for idx in target_indices:
|
157 |
+
image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
|
158 |
+
normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
|
159 |
+
depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
160 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
161 |
+
pose = target_cameras[idx]
|
162 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
163 |
+
|
164 |
+
image_list.append(image)
|
165 |
+
alpha_list.append(alpha)
|
166 |
+
depth_list.append(depth)
|
167 |
+
normal_list.append(normal)
|
168 |
+
pose_list.append(pose)
|
169 |
+
|
170 |
+
except Exception as e:
|
171 |
+
print(e)
|
172 |
+
index = np.random.randint(0, len(self.paths))
|
173 |
+
continue
|
174 |
+
|
175 |
+
break
|
176 |
+
|
177 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
178 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
179 |
+
depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
|
180 |
+
normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
|
181 |
+
w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
|
182 |
+
c2ws = torch.linalg.inv(w2cs).float()
|
183 |
+
|
184 |
+
normals = normals * 2.0 - 1.0
|
185 |
+
normals = F.normalize(normals, dim=1)
|
186 |
+
normals = (normals + 1.0) / 2.0
|
187 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
188 |
+
|
189 |
+
# random rotation along z axis
|
190 |
+
if self.camera_rotation:
|
191 |
+
degree = np.random.uniform(0, math.pi * 2)
|
192 |
+
rot = torch.tensor([
|
193 |
+
[np.cos(degree), -np.sin(degree), 0, 0],
|
194 |
+
[np.sin(degree), np.cos(degree), 0, 0],
|
195 |
+
[0, 0, 1, 0],
|
196 |
+
[0, 0, 0, 1],
|
197 |
+
]).unsqueeze(0).float()
|
198 |
+
c2ws = torch.matmul(rot, c2ws)
|
199 |
+
|
200 |
+
# rotate normals
|
201 |
+
N, _, H, W = normals.shape
|
202 |
+
normals = normals * 2.0 - 1.0
|
203 |
+
normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
|
204 |
+
normals = F.normalize(normals, dim=1)
|
205 |
+
normals = (normals + 1.0) / 2.0
|
206 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
207 |
+
|
208 |
+
# random scaling
|
209 |
+
if np.random.rand() < 0.5:
|
210 |
+
scale = np.random.uniform(0.8, 1.0)
|
211 |
+
c2ws[:, :3, 3] *= scale
|
212 |
+
depths *= scale
|
213 |
+
|
214 |
+
# instrinsics of perspective cameras
|
215 |
+
K = FOV_to_intrinsics(self.fov)
|
216 |
+
Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
|
217 |
+
|
218 |
+
data = {
|
219 |
+
'input_images': images[:self.input_view_num], # (6, 3, H, W)
|
220 |
+
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
|
221 |
+
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
|
222 |
+
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
|
223 |
+
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
|
224 |
+
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
|
225 |
+
|
226 |
+
# lrm generator input and supervision
|
227 |
+
'target_images': images[self.input_view_num:], # (V, 3, H, W)
|
228 |
+
'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
|
229 |
+
'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
|
230 |
+
'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
|
231 |
+
'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
|
232 |
+
'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
|
233 |
+
|
234 |
+
'depth_available': 1,
|
235 |
+
}
|
236 |
+
return data
|
237 |
+
|
238 |
+
|
239 |
+
class ValidationData(Dataset):
|
240 |
+
def __init__(self,
|
241 |
+
root_dir='objaverse/',
|
242 |
+
input_view_num=6,
|
243 |
+
input_image_size=256,
|
244 |
+
fov=50,
|
245 |
+
):
|
246 |
+
self.root_dir = Path(root_dir)
|
247 |
+
self.input_view_num = input_view_num
|
248 |
+
self.input_image_size = input_image_size
|
249 |
+
self.fov = fov
|
250 |
+
|
251 |
+
self.paths = sorted(os.listdir(self.root_dir))
|
252 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
253 |
+
|
254 |
+
cam_distance = 2.5
|
255 |
+
azimuths = np.array([30, 90, 150, 210, 270, 330])
|
256 |
+
elevations = np.array([30, -20, 30, -20, 30, -20])
|
257 |
+
azimuths = np.deg2rad(azimuths)
|
258 |
+
elevations = np.deg2rad(elevations)
|
259 |
+
|
260 |
+
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
|
261 |
+
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
|
262 |
+
z = cam_distance * np.sin(elevations)
|
263 |
+
|
264 |
+
cam_locations = np.stack([x, y, z], axis=-1)
|
265 |
+
cam_locations = torch.from_numpy(cam_locations).float()
|
266 |
+
c2ws = center_looking_at_camera_pose(cam_locations)
|
267 |
+
self.c2ws = c2ws.float()
|
268 |
+
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
|
269 |
+
|
270 |
+
render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
|
271 |
+
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
|
272 |
+
self.render_c2ws = render_c2ws.float()
|
273 |
+
self.render_Ks = render_Ks.float()
|
274 |
+
|
275 |
+
def __len__(self):
|
276 |
+
return len(self.paths)
|
277 |
+
|
278 |
+
def load_im(self, path, color):
|
279 |
+
"""Replace background pixel with random color in rendering."""
|
280 |
+
pil_img = Image.open(path)
|
281 |
+
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
282 |
+
|
283 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
284 |
+
if image.shape[-1] == 4:
|
285 |
+
alpha = image[:, :, 3:]
|
286 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
287 |
+
else:
|
288 |
+
alpha = np.ones_like(image[:, :, :1])
|
289 |
+
|
290 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
291 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
292 |
+
return image, alpha
|
293 |
+
|
294 |
+
def __getitem__(self, index):
|
295 |
+
# load data
|
296 |
+
input_image_path = os.path.join(self.root_dir, self.paths[index])
|
297 |
+
|
298 |
+
'''background color, default: white'''
|
299 |
+
# color = np.random.uniform(0.48, 0.52)
|
300 |
+
bkg_color = [1.0, 1.0, 1.0]
|
301 |
+
|
302 |
+
image_list = []
|
303 |
+
alpha_list = []
|
304 |
+
|
305 |
+
for idx in range(self.input_view_num):
|
306 |
+
image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
|
307 |
+
image_list.append(image)
|
308 |
+
alpha_list.append(alpha)
|
309 |
+
|
310 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
311 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
312 |
+
|
313 |
+
data = {
|
314 |
+
'input_images': images, # (6, 3, H, W)
|
315 |
+
'input_alphas': alphas, # (6, 1, H, W)
|
316 |
+
'input_c2ws': self.c2ws, # (6, 4, 4)
|
317 |
+
'input_Ks': self.Ks, # (6, 3, 3)
|
318 |
+
|
319 |
+
'render_c2ws': self.render_c2ws,
|
320 |
+
'render_Ks': self.render_Ks,
|
321 |
+
}
|
322 |
+
return data
|
src/lib.rs
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
//! Example of a Rust library.
|
|
|
|
src/main.cpp
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
#include <cstdio>
|
2 |
-
|
3 |
-
#include <rerun.hpp>
|
4 |
-
|
5 |
-
int main(int argc, const char* argv[]) {
|
6 |
-
printf("Hello, World!\n");
|
7 |
-
return 0;
|
8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.rs
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
//! Example of a Rust binary.
|
2 |
-
|
3 |
-
fn main() {
|
4 |
-
println!("Hello, PROJ_NAME!");
|
5 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
src/model.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
11 |
+
from torchvision.transforms import v2
|
12 |
+
from torchvision.utils import make_grid, save_image
|
13 |
+
|
14 |
+
from src.utils.train_util import instantiate_from_config
|
15 |
+
|
16 |
+
|
17 |
+
class MVRecon(pl.LightningModule):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
lrm_generator_config,
|
21 |
+
lrm_path=None,
|
22 |
+
input_size=256,
|
23 |
+
render_size=192,
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.input_size = input_size
|
28 |
+
self.render_size = render_size
|
29 |
+
|
30 |
+
# init modules
|
31 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
32 |
+
if lrm_path is not None:
|
33 |
+
lrm_ckpt = torch.load(lrm_path)
|
34 |
+
self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
|
35 |
+
|
36 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
37 |
+
|
38 |
+
self.validation_step_outputs = []
|
39 |
+
|
40 |
+
def on_fit_start(self):
|
41 |
+
if self.global_rank == 0:
|
42 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
43 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
44 |
+
|
45 |
+
def prepare_batch_data(self, batch):
|
46 |
+
lrm_generator_input = {}
|
47 |
+
render_gt = {} # for supervision
|
48 |
+
|
49 |
+
# input images
|
50 |
+
images = batch['input_images']
|
51 |
+
images = v2.functional.resize(
|
52 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
53 |
+
|
54 |
+
lrm_generator_input['images'] = images.to(self.device)
|
55 |
+
|
56 |
+
# input cameras and render cameras
|
57 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
58 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
59 |
+
target_c2ws = batch['target_c2ws'].flatten(-2)
|
60 |
+
target_Ks = batch['target_Ks'].flatten(-2)
|
61 |
+
render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
|
62 |
+
render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
|
63 |
+
render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
|
64 |
+
|
65 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
66 |
+
input_intrinsics = torch.stack([
|
67 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
68 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
69 |
+
], dim=-1)
|
70 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
71 |
+
|
72 |
+
# add noise to input cameras
|
73 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
74 |
+
|
75 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
76 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
77 |
+
|
78 |
+
# target images
|
79 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
80 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
81 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
82 |
+
|
83 |
+
# random crop
|
84 |
+
render_size = np.random.randint(self.render_size, 513)
|
85 |
+
target_images = v2.functional.resize(
|
86 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
87 |
+
target_depths = v2.functional.resize(
|
88 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
89 |
+
target_alphas = v2.functional.resize(
|
90 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
91 |
+
|
92 |
+
crop_params = v2.RandomCrop.get_params(
|
93 |
+
target_images, output_size=(self.render_size, self.render_size))
|
94 |
+
target_images = v2.functional.crop(target_images, *crop_params)
|
95 |
+
target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
|
96 |
+
target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
|
97 |
+
|
98 |
+
lrm_generator_input['render_size'] = render_size
|
99 |
+
lrm_generator_input['crop_params'] = crop_params
|
100 |
+
|
101 |
+
render_gt['target_images'] = target_images.to(self.device)
|
102 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
103 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
104 |
+
|
105 |
+
return lrm_generator_input, render_gt
|
106 |
+
|
107 |
+
def prepare_validation_batch_data(self, batch):
|
108 |
+
lrm_generator_input = {}
|
109 |
+
|
110 |
+
# input images
|
111 |
+
images = batch['input_images']
|
112 |
+
images = v2.functional.resize(
|
113 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
114 |
+
|
115 |
+
lrm_generator_input['images'] = images.to(self.device)
|
116 |
+
|
117 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
118 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
119 |
+
|
120 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
121 |
+
input_intrinsics = torch.stack([
|
122 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
123 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
124 |
+
], dim=-1)
|
125 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
126 |
+
|
127 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
128 |
+
|
129 |
+
render_c2ws = batch['render_c2ws'].flatten(-2)
|
130 |
+
render_Ks = batch['render_Ks'].flatten(-2)
|
131 |
+
render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
|
132 |
+
|
133 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
134 |
+
lrm_generator_input['render_size'] = 384
|
135 |
+
lrm_generator_input['crop_params'] = None
|
136 |
+
|
137 |
+
return lrm_generator_input
|
138 |
+
|
139 |
+
def forward_lrm_generator(
|
140 |
+
self,
|
141 |
+
images,
|
142 |
+
cameras,
|
143 |
+
render_cameras,
|
144 |
+
render_size=192,
|
145 |
+
crop_params=None,
|
146 |
+
chunk_size=1,
|
147 |
+
):
|
148 |
+
planes = torch.utils.checkpoint.checkpoint(
|
149 |
+
self.lrm_generator.forward_planes,
|
150 |
+
images,
|
151 |
+
cameras,
|
152 |
+
use_reentrant=False,
|
153 |
+
)
|
154 |
+
frames = []
|
155 |
+
for i in range(0, render_cameras.shape[1], chunk_size):
|
156 |
+
frames.append(
|
157 |
+
torch.utils.checkpoint.checkpoint(
|
158 |
+
self.lrm_generator.synthesizer,
|
159 |
+
planes,
|
160 |
+
cameras=render_cameras[:, i:i+chunk_size],
|
161 |
+
render_size=render_size,
|
162 |
+
crop_params=crop_params,
|
163 |
+
use_reentrant=False
|
164 |
+
)
|
165 |
+
)
|
166 |
+
frames = {
|
167 |
+
k: torch.cat([r[k] for r in frames], dim=1)
|
168 |
+
for k in frames[0].keys()
|
169 |
+
}
|
170 |
+
return frames
|
171 |
+
|
172 |
+
def forward(self, lrm_generator_input):
|
173 |
+
images = lrm_generator_input['images']
|
174 |
+
cameras = lrm_generator_input['cameras']
|
175 |
+
render_cameras = lrm_generator_input['render_cameras']
|
176 |
+
render_size = lrm_generator_input['render_size']
|
177 |
+
crop_params = lrm_generator_input['crop_params']
|
178 |
+
|
179 |
+
out = self.forward_lrm_generator(
|
180 |
+
images,
|
181 |
+
cameras,
|
182 |
+
render_cameras,
|
183 |
+
render_size=render_size,
|
184 |
+
crop_params=crop_params,
|
185 |
+
chunk_size=1,
|
186 |
+
)
|
187 |
+
render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
|
188 |
+
render_depths = out['images_depth']
|
189 |
+
render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
|
190 |
+
|
191 |
+
out = {
|
192 |
+
'render_images': render_images,
|
193 |
+
'render_depths': render_depths,
|
194 |
+
'render_alphas': render_alphas,
|
195 |
+
}
|
196 |
+
return out
|
197 |
+
|
198 |
+
def training_step(self, batch, batch_idx):
|
199 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
200 |
+
|
201 |
+
render_out = self.forward(lrm_generator_input)
|
202 |
+
|
203 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
204 |
+
|
205 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
206 |
+
|
207 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
208 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
209 |
+
N_in = lrm_generator_input['images'].shape[1]
|
210 |
+
|
211 |
+
input_images = v2.functional.resize(
|
212 |
+
lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
|
213 |
+
input_images = torch.cat(
|
214 |
+
[input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
|
215 |
+
|
216 |
+
input_images = rearrange(
|
217 |
+
input_images, 'b n c h w -> b c h (n w)')
|
218 |
+
target_images = rearrange(
|
219 |
+
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
220 |
+
render_images = rearrange(
|
221 |
+
render_out['render_images'], 'b n c h w -> b c h (n w)')
|
222 |
+
target_alphas = rearrange(
|
223 |
+
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
224 |
+
render_alphas = rearrange(
|
225 |
+
repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
226 |
+
target_depths = rearrange(
|
227 |
+
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
228 |
+
render_depths = rearrange(
|
229 |
+
repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
230 |
+
MAX_DEPTH = torch.max(target_depths)
|
231 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
232 |
+
render_depths = render_depths / MAX_DEPTH
|
233 |
+
|
234 |
+
grid = torch.cat([
|
235 |
+
input_images,
|
236 |
+
target_images, render_images,
|
237 |
+
target_alphas, render_alphas,
|
238 |
+
target_depths, render_depths,
|
239 |
+
], dim=-2)
|
240 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
241 |
+
|
242 |
+
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
|
243 |
+
|
244 |
+
return loss
|
245 |
+
|
246 |
+
def compute_loss(self, render_out, render_gt):
|
247 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
248 |
+
render_images = render_out['render_images']
|
249 |
+
target_images = render_gt['target_images'].to(render_images)
|
250 |
+
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
251 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
252 |
+
|
253 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
254 |
+
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
255 |
+
|
256 |
+
render_alphas = render_out['render_alphas']
|
257 |
+
target_alphas = render_gt['target_alphas']
|
258 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
259 |
+
|
260 |
+
loss = loss_mse + loss_lpips + loss_mask
|
261 |
+
|
262 |
+
prefix = 'train'
|
263 |
+
loss_dict = {}
|
264 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
265 |
+
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
266 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
267 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
268 |
+
|
269 |
+
return loss, loss_dict
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def validation_step(self, batch, batch_idx):
|
273 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
274 |
+
|
275 |
+
render_out = self.forward(lrm_generator_input)
|
276 |
+
render_images = render_out['render_images']
|
277 |
+
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
278 |
+
|
279 |
+
self.validation_step_outputs.append(render_images)
|
280 |
+
|
281 |
+
def on_validation_epoch_end(self):
|
282 |
+
images = torch.cat(self.validation_step_outputs, dim=-1)
|
283 |
+
|
284 |
+
all_images = self.all_gather(images)
|
285 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
286 |
+
|
287 |
+
if self.global_rank == 0:
|
288 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
289 |
+
|
290 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
291 |
+
save_image(grid, image_path)
|
292 |
+
print(f"Saved image to {image_path}")
|
293 |
+
|
294 |
+
self.validation_step_outputs.clear()
|
295 |
+
|
296 |
+
def configure_optimizers(self):
|
297 |
+
lr = self.learning_rate
|
298 |
+
|
299 |
+
params = []
|
300 |
+
|
301 |
+
lrm_params_fast, lrm_params_slow = [], []
|
302 |
+
for n, p in self.lrm_generator.named_parameters():
|
303 |
+
if 'adaLN_modulation' in n or 'camera_embedder' in n:
|
304 |
+
lrm_params_fast.append(p)
|
305 |
+
else:
|
306 |
+
lrm_params_slow.append(p)
|
307 |
+
params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
|
308 |
+
params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
|
309 |
+
|
310 |
+
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
|
311 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
|
312 |
+
|
313 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|