diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d05f48a96da48495a45c4b02288fb020515085b8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,57 @@
+# Ignore Python cache files
+**/__pycache__
+
+# Ignore compiled Python files
+*.pyc
+
+# Ignore editor-specific files
+.vscode/
+.idea/
+
+# Ignore operating system files
+.DS_Store
+Thumbs.db
+
+# Ignore log files
+*.log
+
+# Ignore temporary and cache files
+*.tmp
+*.cache
+
+# Ignore build artifacts
+/build/
+/dist/
+
+# Ignore virtual environment files
+/venv/
+/.venv/
+
+# Ignore package manager files
+/node_modules/
+/yarn.lock
+/package-lock.json
+
+# Ignore database files
+*.db
+*.sqlite
+
+# Ignore secret files
+*.secret
+
+# Ignore compiled binaries
+*.exe
+*.dll
+*.so
+*.dylib
+
+# Ignore backup files
+*.bak
+*.swp
+*.swo
+*.~*
+
+
+# Ignore exp result files
+dumps/
+exps/
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2a457f756abfa8faa67d901bbecca93a37f7e6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,394 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from PIL import Image
+import numpy as np
+import gradio as gr
+
+
+def assert_input_image(input_front_image, input_back_image):
+ if input_front_image is None:
+ raise gr.Error("No front image selected or uploaded!")
+ if input_back_image is None:
+ raise gr.Error("No back image selected or uploaded!")
+
+def prepare_working_dir():
+ import tempfile
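+    # Return the TemporaryDirectory object itself (not its path) so the directory
+    # stays alive while the caller holds it; downstream code reads working_dir.name.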
+ working_dir = tempfile.TemporaryDirectory()
+ return working_dir
+
+def init_preprocessor():
+ from openlrm.utils.preprocess import Preprocessor
+ global preprocessor
+ preprocessor = Preprocessor()
+
+def preprocess_fn(image_in_front: np.ndarray, image_in_back: np.ndarray, remove_bg: bool, recenter: bool, working_dir):
+ # save front image first
+ image_raw_front = os.path.join(working_dir.name, "raw_front.png")
+ with Image.fromarray(image_in_front) as img:
+ img.save(image_raw_front)
+ image_out_front = os.path.join(working_dir.name, "front/rembg_front.png")
+
+ # save back image first
+ image_raw_back = os.path.join(working_dir.name, "raw_back.png")
+ with Image.fromarray(image_in_back) as img:
+ img.save(image_raw_back)
+ image_out_back = os.path.join(working_dir.name, "back/rembg_back.png")
+
+ # process the front and back image.
+ success_front = preprocessor.preprocess(image_path=image_raw_front, save_path=image_out_front, rmbg=remove_bg, recenter=recenter)
+ success_back = preprocessor.preprocess(image_path=image_raw_back, save_path=image_out_back, rmbg=remove_bg, recenter=recenter)
+    assert success_front and success_back, "Background removal / recentering failed in preprocess_fn!"
+ return image_out_front, image_out_back
+
+
+def demo_openlrm(infer_impl):
+
+ def core_fn(image_front: str, image_back: str, source_cam_dist: float, working_dir):
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
+ dump_mesh_path = os.path.join(working_dir.name, "output.ply")
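+        # Only the turntable video is exported here; mesh export is disabled,
+        # but a .ply path is still provided for the inferrer's signature.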
+ infer_impl(
+ image_path=image_front,
+ source_cam_dist=source_cam_dist,
+ export_video=True,
+ export_mesh=False,
+ dump_video_path=dump_video_path,
+ dump_mesh_path=dump_mesh_path,
+ image_path_back=image_back,
+ )
+ return dump_video_path
+
+ def example_fn(input_front_image: np.ndarray, input_back_image: np.ndarray):
+ from gradio.utils import get_cache_folder
+ working_dir = get_cache_folder()
+ processed_front_image, processed_back_image = preprocess_fn(
+ image_in_front=input_front_image,
+ image_in_back=input_back_image,
+ remove_bg=True,
+ recenter=True,
+ working_dir=working_dir,
+ )
+ video = core_fn(
+ image_front=processed_front_image,
+ image_back=processed_back_image,
+ source_cam_dist=2.0,
+ working_dir=working_dir,
+ )
+ return processed_front_image, processed_back_image, video
+
+ _TITLE = '''🔥 🔥 🔥 Tailor3D: Customized 3D Assets Editing and Generation with Dual-Side Images'''
+
+ _DESCRIPTION = '''
+
+
+
+
+    We propose Tailor3D, a novel pipeline that creates customized 3D assets from editable dual-side images via feed-forward reconstruction.
+
+    Here we show the final step of Tailor3D: given the edited front and back views of the object, we can produce the 3D object within a few seconds.
+
+    Disclaimer: This demo uses the `Tailor3D-base-1.1` model with a 288x288 rendering resolution for a quick demonstration.
+ '''
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+
+ # HEADERS
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ with gr.Row():
+ gr.Markdown(_DESCRIPTION)
+
+ # DISPLAY
+ with gr.Row():
+ gr.Markdown(
+ """
+                ## 🖼️ Input: The front and back input images.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_input_front_image"):
+ with gr.TabItem('Input Front-view Image'):
+ with gr.Row():
+ input_front_image = gr.Image(label="Input Front Image", image_mode="RGBA", width="auto", sources="upload", type="numpy", elem_id="content_image")
+
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_input_back_image"):
+ with gr.TabItem('Input Back-view Image'):
+ with gr.Row():
+ input_back_image = gr.Image(label="Input Back Image", image_mode="RGBA", width="auto", sources="upload", type="numpy", elem_id="content_image")
+ with gr.Row():
+ gr.Markdown(
+ """
+ ## 🛠️ Preprocess: Remove the background and center the object.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_processed_image"):
+ with gr.TabItem('Processed Front-view Image'):
+ with gr.Row():
+ processed_front_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", width="auto", interactive=False)
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_processed_image"):
+ with gr.TabItem('Processed Back-view Image'):
+ with gr.Row():
+ processed_back_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", width="auto", interactive=False)
+ with gr.Row():
+ gr.Markdown(
+ """
+                ## 🚀 Output: The rendered video of the 3D object.
+                Note that the output is a 3D mesh; for convenience, we showcase it as a video orbiting the object.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_render_video"):
+ with gr.TabItem('Rendered Video'):
+ with gr.Row():
+ output_video = gr.Video(label="Rendered Video", format="mp4", width="auto", autoplay=True)
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id="openlrm_attrs"):
+ with gr.TabItem('Settings'):
+ with gr.Column(variant='panel'):
+ gr.Markdown(
+ """
+                                Best practice:
+                                Use centered objects of a reasonable size, and try adjusting the source camera distance.
+ """
+ )
+ checkbox_rembg = gr.Checkbox(True, label='Remove background')
+ checkbox_recenter = gr.Checkbox(True, label='Recenter the object')
+ slider_cam_dist = gr.Slider(1.0, 3.5, value=2.0, step=0.1, label="Source Camera Distance")
+ submit = gr.Button('Generate', elem_id="openlrm_generate", variant='primary')
+
+ # EXAMPLES
+ with gr.Row():
+ gr.Markdown(
+ """
+            ## Examples from the paper.
+            ### A. 3D Style Transfer
+            Here we keep the object identity and only transfer the style.
+
+            **Line 1: A Pop Mart boy in astronaut, blue, traditional Chinese, and grey styles.**
+ """
+ )
+ with gr.Row():
+ examples = [
+ ['assets/sample_input/demo/front/boy_astronaut.png', 'assets/sample_input/demo/back/boy_astronaut.png'],
+ ['assets/sample_input/demo/front/boy_blue.png', 'assets/sample_input/demo/back/boy_blue.png'],
+ ['assets/sample_input/demo/front/boy_chinese_style.png', 'assets/sample_input/demo/back/boy_chinese_style.png'],
+ ['assets/sample_input/demo/front/boy_grey_clothes.png', 'assets/sample_input/demo/back/boy_grey_clothes.png'],
+ ]
+
+ for example in examples:
+ with gr.Column(scale=1):
+ gr.Examples(
+ examples=[example],
+ inputs=[input_front_image, input_back_image],
+ outputs=[processed_front_image, processed_back_image, output_video],
+ fn=example_fn,
+ cache_examples=bool(os.getenv('SPACE_ID')),
+ examples_per_page=3,
+ )
+
+ # # EXAMPLES
+ # with gr.Row():
+ # gr.Markdown(
+ # """
+ # **Line 2: A LEGO model featuring an astronaut, green and red elements, and a wizard theme.**
+ # """
+ # )
+ # with gr.Row():
+ # examples = [
+ # ['assets/sample_input/demo/front/lego_astronaut.png', 'assets/sample_input/demo/back/lego_astronaut.png'],
+ # ['assets/sample_input/demo/front/lego_green.png', 'assets/sample_input/demo/back/lego_green.png'],
+    #             ['assets/sample_input/demo/front/lego_red.png', 'assets/sample_input/demo/back/lego_red.png'],
+ # ['assets/sample_input/demo/front/lego_wizard.png', 'assets/sample_input/demo/back/lego_wizard.png'],
+ # ]
+
+ # for example in examples:
+ # with gr.Column(scale=0.3):
+ # gr.Examples(
+ # examples=[example],
+ # inputs=[input_front_image, input_back_image],
+ # outputs=None, # [processed_image, output_video],
+ # fn=None, # example_fn,
+ # cache_examples=bool(os.getenv('SPACE_ID')),
+ # examples_per_page=3,
+ # )
+ # with gr.Row():
+ # gr.Markdown(
+ # """
+    #             **Line 3: A Marvel boy featuring Captain America, Iron Man, Spider-Man, and Superman themes.**
+ # """
+ # )
+ # with gr.Row():
+ # examples = [
+ # ['assets/sample_input/demo/front/marvel_captain.png', 'assets/sample_input/demo/back/marvel_captain.png'],
+    #             ['assets/sample_input/demo/front/marvel_ironman.png', 'assets/sample_input/demo/back/marvel_ironman.png'],
+ # ['assets/sample_input/demo/front/marvel_spiderman.png', 'assets/sample_input/demo/back/marvel_spiderman.png'],
+ # ['assets/sample_input/demo/front/marvel_superman.png', 'assets/sample_input/demo/back/marvel_superman.png'],
+ # ]
+
+ # for example in examples:
+ # with gr.Column(scale=0.3):
+ # gr.Examples(
+ # examples=[example],
+ # inputs=[input_front_image, input_back_image],
+ # outputs=None, # [processed_image, output_video],
+ # fn=None, # example_fn,
+ # cache_examples=bool(os.getenv('SPACE_ID')),
+ # examples_per_page=3,
+ # )
+ # # EXAMPLES
+ # with gr.Row():
+ # gr.Markdown(
+ # """
+ # ### B. 3D Generative Geometry or Pattern Fill
+
+ # Here, we start with a simple object and gradually add various accessories, costumes, or patterns step by step. We only showcase the final effect after multiple rounds of decoration.
+
+ # **Line 4: Initial object: sofa, dog, penguin, house.**
+ # """
+ # )
+ # with gr.Row():
+ # examples = [
+ # ['assets/sample_input/demo/front/sofa.png', 'assets/sample_input/demo/back/sofa.png'],
+ # ['assets/sample_input/demo/front/penguin.png', 'assets/sample_input/demo/back/penguin.png'],
+ # ['assets/sample_input/demo/front/house.png', 'assets/sample_input/demo/back/house.png'],
+ # ]
+
+ # for example in examples:
+ # with gr.Column(scale=0.3):
+ # gr.Examples(
+ # examples=[example],
+ # inputs=[input_front_image, input_back_image],
+ # outputs=None, # [processed_image, output_video],
+ # fn=None, # example_fn,
+ # cache_examples=bool(os.getenv('SPACE_ID')),
+ # examples_per_page=3,
+ # )
+
+ # with gr.Row():
+ # gr.Markdown(
+ # """
+ # ### C. 3D Style Fusion
+
+ # We will maintain a consistent front style of the object while continuously changing the back style, blending the two different styles into one object.
+
+ # **Line 5: A bird with different back styles.**
+ # """
+ # )
+ # with gr.Row():
+ # examples = [
+ # ['assets/sample_input/demo/front/bird.png', 'assets/sample_input/demo/back/bird.png'],
+ # ['assets/sample_input/demo/front/bird_brownblue.png', 'assets/sample_input/demo/back/bird_brownblue.png'],
+ # ['assets/sample_input/demo/front/bird_rainbow.png', 'assets/sample_input/demo/back/bird_rainbow.png'],
+ # ['assets/sample_input/demo/front/bird_whitered.png', 'assets/sample_input/demo/back/bird_whitered.png'],
+ # ]
+
+ # for example in examples:
+ # with gr.Column(scale=0.3):
+ # gr.Examples(
+ # examples=[example],
+ # inputs=[input_front_image, input_back_image],
+ # outputs=None, # [processed_image, output_video],
+ # fn=None, # example_fn,
+ # cache_examples=bool(os.getenv('SPACE_ID')),
+ # examples_per_page=3,
+ # )
+
+ # with gr.Row():
+ # gr.Markdown(
+ # """
+ # ### Others
+ # I vote for kunkun forever, I am really an I-kUN and have heard many of his songs.
+ # """
+ # )
+ # with gr.Row():
+ # examples = [
+ # ['assets/sample_input/demo/front/loopy.png', 'assets/sample_input/demo/back/loopy.png'],
+ # ['assets/sample_input/demo/front/mario.png', 'assets/sample_input/demo/back/mario.png'],
+ # ['assets/sample_input/demo/front/armor.png', 'assets/sample_input/demo/back/armor.png'],
+ # ['assets/sample_input/demo/front/kunkun_law.png', 'assets/sample_input/demo/back/kunkun_law.png'],
+ # ]
+
+ # for example in examples:
+ # with gr.Column(scale=0.3):
+ # gr.Examples(
+ # examples=[example],
+ # inputs=[input_front_image, input_back_image],
+ # outputs=None, # [processed_image, output_video],
+ # fn=None, # example_fn,
+ # cache_examples=bool(os.getenv('SPACE_ID')),
+ # examples_per_page=3,
+ # )
+
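+        # Event chain: validate the two inputs -> create a temporary working dir
+        # -> remove background / recenter both views -> reconstruct and render the video.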
+ working_dir = gr.State()
+ submit.click(
+ fn=assert_input_image,
+ inputs=[input_front_image, input_back_image],
+ queue=False,
+ ).success(
+ fn=prepare_working_dir,
+ outputs=[working_dir],
+ queue=False,
+ ).success(
+ fn=preprocess_fn,
+ inputs=[input_front_image, input_back_image, checkbox_rembg, checkbox_recenter, working_dir],
+ outputs=[processed_front_image, processed_back_image],
+ ).success(
+ fn=core_fn,
+ inputs=[processed_front_image, processed_back_image, slider_cam_dist, working_dir],
+ outputs=[output_video],
+ )
+
+ demo.queue()
+ demo.launch()
+
+
+def launch_gradio_app():
+
+ os.environ.update({
+ "APP_ENABLED": "1",
+ "APP_MODEL_NAME": "alexzyqi/Tailor3D-Base-1.0",
+ "APP_PRETRAIN_MODEL_NAME": "zxhezexin/openlrm-mix-base-1.1",
+ "APP_INFER": "./configs/infer-gradio-base.yaml",
+ "APP_TYPE": "infer.lrm",
+ "NUMBA_THREADING_LAYER": 'omp',
+ })
+
+ from openlrm.runners import REGISTRY_RUNNERS
+ from openlrm.runners.infer.base_inferrer import Inferrer
+ InferrerClass : Inferrer = REGISTRY_RUNNERS[os.getenv("APP_TYPE")]
+ with InferrerClass() as inferrer:
+ init_preprocessor()
+ if not bool(os.getenv('SPACE_ID')):
+ from openlrm.utils.proxy import no_proxy
+ demo = no_proxy(demo_openlrm)
+ else:
+ demo = demo_openlrm
+ demo(infer_impl=inferrer.infer_single)
+
+
+if __name__ == '__main__':
+
+ launch_gradio_app()
diff --git a/assets/sample_input/demo/back/armor.png b/assets/sample_input/demo/back/armor.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d4c793f816d58fc0c232513f3be77d7d0508c06
Binary files /dev/null and b/assets/sample_input/demo/back/armor.png differ
diff --git a/assets/sample_input/demo/back/bird.png b/assets/sample_input/demo/back/bird.png
new file mode 100644
index 0000000000000000000000000000000000000000..71a53ea433ea7604e947a7a435ba741238d9b174
Binary files /dev/null and b/assets/sample_input/demo/back/bird.png differ
diff --git a/assets/sample_input/demo/back/bird_brownblue.png b/assets/sample_input/demo/back/bird_brownblue.png
new file mode 100644
index 0000000000000000000000000000000000000000..126188e39e6a665a0f2c0f7c2639a228c75c5f61
Binary files /dev/null and b/assets/sample_input/demo/back/bird_brownblue.png differ
diff --git a/assets/sample_input/demo/back/bird_rainbow.png b/assets/sample_input/demo/back/bird_rainbow.png
new file mode 100644
index 0000000000000000000000000000000000000000..e97ccee2cee23e38d8afb751f0fc5b938df6d6e1
Binary files /dev/null and b/assets/sample_input/demo/back/bird_rainbow.png differ
diff --git a/assets/sample_input/demo/back/bird_whitered.png b/assets/sample_input/demo/back/bird_whitered.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce0d5858db2680cd173d7f5fe88e2273c36ac33d
Binary files /dev/null and b/assets/sample_input/demo/back/bird_whitered.png differ
diff --git a/assets/sample_input/demo/back/boy_astronaut.png b/assets/sample_input/demo/back/boy_astronaut.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7e107adff8d0bf16d1a98e899a5b7814edbcbbf
Binary files /dev/null and b/assets/sample_input/demo/back/boy_astronaut.png differ
diff --git a/assets/sample_input/demo/back/boy_blue.png b/assets/sample_input/demo/back/boy_blue.png
new file mode 100644
index 0000000000000000000000000000000000000000..79faeb2b3cdf0480f66bfb92fcd42710027dd607
Binary files /dev/null and b/assets/sample_input/demo/back/boy_blue.png differ
diff --git a/assets/sample_input/demo/back/boy_chinese_style.png b/assets/sample_input/demo/back/boy_chinese_style.png
new file mode 100644
index 0000000000000000000000000000000000000000..0dbe95df5a9053200cae6f49141b2839adad0fdd
Binary files /dev/null and b/assets/sample_input/demo/back/boy_chinese_style.png differ
diff --git a/assets/sample_input/demo/back/boy_grey_clothes.png b/assets/sample_input/demo/back/boy_grey_clothes.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c0603a2b0724b3b95a088513a141318cdf7a1fc
Binary files /dev/null and b/assets/sample_input/demo/back/boy_grey_clothes.png differ
diff --git a/assets/sample_input/demo/back/house.png b/assets/sample_input/demo/back/house.png
new file mode 100644
index 0000000000000000000000000000000000000000..f729175e7a5c5e776acf6a1138c11d472d8bd617
Binary files /dev/null and b/assets/sample_input/demo/back/house.png differ
diff --git a/assets/sample_input/demo/back/kunkun_law.png b/assets/sample_input/demo/back/kunkun_law.png
new file mode 100644
index 0000000000000000000000000000000000000000..f796d29e1feb8e5a4c60dc57ddf65553c3ffed35
Binary files /dev/null and b/assets/sample_input/demo/back/kunkun_law.png differ
diff --git a/assets/sample_input/demo/back/lego_astronaut.png b/assets/sample_input/demo/back/lego_astronaut.png
new file mode 100644
index 0000000000000000000000000000000000000000..cbeff41be23598658c8d9a383c5d8bfdbbf70a9f
Binary files /dev/null and b/assets/sample_input/demo/back/lego_astronaut.png differ
diff --git a/assets/sample_input/demo/back/lego_green.png b/assets/sample_input/demo/back/lego_green.png
new file mode 100644
index 0000000000000000000000000000000000000000..cbabe9b9002f29fed85921d8c560e0934c71f8e4
Binary files /dev/null and b/assets/sample_input/demo/back/lego_green.png differ
diff --git a/assets/sample_input/demo/back/lego_red.png b/assets/sample_input/demo/back/lego_red.png
new file mode 100644
index 0000000000000000000000000000000000000000..c49d948bd40e1f10a8e31cf0e335e7d2fe4c3d14
Binary files /dev/null and b/assets/sample_input/demo/back/lego_red.png differ
diff --git a/assets/sample_input/demo/back/lego_wizard.png b/assets/sample_input/demo/back/lego_wizard.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd7739fceac0877233dd9f721f1ba36cc260a839
Binary files /dev/null and b/assets/sample_input/demo/back/lego_wizard.png differ
diff --git a/assets/sample_input/demo/back/loopy.png b/assets/sample_input/demo/back/loopy.png
new file mode 100644
index 0000000000000000000000000000000000000000..f16bf1487e85fdc5b843b494e2706220eeaee150
Binary files /dev/null and b/assets/sample_input/demo/back/loopy.png differ
diff --git a/assets/sample_input/demo/back/mario.png b/assets/sample_input/demo/back/mario.png
new file mode 100644
index 0000000000000000000000000000000000000000..bea2bda1b0a792a111a9877486c7d72614220412
Binary files /dev/null and b/assets/sample_input/demo/back/mario.png differ
diff --git a/assets/sample_input/demo/back/marvel_captain.png b/assets/sample_input/demo/back/marvel_captain.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef93ef0b62347bbf8f7b2a3cbc1581e709410c3c
Binary files /dev/null and b/assets/sample_input/demo/back/marvel_captain.png differ
diff --git a/assets/sample_input/demo/back/marvel_ironman.png b/assets/sample_input/demo/back/marvel_ironman.png
new file mode 100644
index 0000000000000000000000000000000000000000..573c584f25b2f6cdea66732d983e7b24f0b63a81
Binary files /dev/null and b/assets/sample_input/demo/back/marvel_ironman.png differ
diff --git a/assets/sample_input/demo/back/marvel_spiderman.png b/assets/sample_input/demo/back/marvel_spiderman.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f232f0e84d2eafbb51e67794513e46edf3171cf
Binary files /dev/null and b/assets/sample_input/demo/back/marvel_spiderman.png differ
diff --git a/assets/sample_input/demo/back/marvel_superman.png b/assets/sample_input/demo/back/marvel_superman.png
new file mode 100644
index 0000000000000000000000000000000000000000..bb06b4e54d46f7b1bd4a66202317466712108fcd
Binary files /dev/null and b/assets/sample_input/demo/back/marvel_superman.png differ
diff --git a/assets/sample_input/demo/back/penguin.png b/assets/sample_input/demo/back/penguin.png
new file mode 100644
index 0000000000000000000000000000000000000000..277f29c5fa3823baf032fea3b3f5c7fd722973d3
Binary files /dev/null and b/assets/sample_input/demo/back/penguin.png differ
diff --git a/assets/sample_input/demo/back/sofa.png b/assets/sample_input/demo/back/sofa.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9afb35bdd489f9284986eeca6eaf1c68047259a
Binary files /dev/null and b/assets/sample_input/demo/back/sofa.png differ
diff --git a/assets/sample_input/demo/front/armor.png b/assets/sample_input/demo/front/armor.png
new file mode 100644
index 0000000000000000000000000000000000000000..892f04db2d34c82a4f2dbeac1c9c4665b76f9d45
Binary files /dev/null and b/assets/sample_input/demo/front/armor.png differ
diff --git a/assets/sample_input/demo/front/bird.png b/assets/sample_input/demo/front/bird.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b
Binary files /dev/null and b/assets/sample_input/demo/front/bird.png differ
diff --git a/assets/sample_input/demo/front/bird_brownblue.png b/assets/sample_input/demo/front/bird_brownblue.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b
Binary files /dev/null and b/assets/sample_input/demo/front/bird_brownblue.png differ
diff --git a/assets/sample_input/demo/front/bird_rainbow.png b/assets/sample_input/demo/front/bird_rainbow.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b
Binary files /dev/null and b/assets/sample_input/demo/front/bird_rainbow.png differ
diff --git a/assets/sample_input/demo/front/bird_whitered.png b/assets/sample_input/demo/front/bird_whitered.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b
Binary files /dev/null and b/assets/sample_input/demo/front/bird_whitered.png differ
diff --git a/assets/sample_input/demo/front/boy_astronaut.png b/assets/sample_input/demo/front/boy_astronaut.png
new file mode 100644
index 0000000000000000000000000000000000000000..85ee3a2cf0f9b7816a2d6dfbf757dba200a87214
Binary files /dev/null and b/assets/sample_input/demo/front/boy_astronaut.png differ
diff --git a/assets/sample_input/demo/front/boy_blue.png b/assets/sample_input/demo/front/boy_blue.png
new file mode 100644
index 0000000000000000000000000000000000000000..cb853234bcfb8ba07ef6ec3168068478d052a57b
Binary files /dev/null and b/assets/sample_input/demo/front/boy_blue.png differ
diff --git a/assets/sample_input/demo/front/boy_chinese_style.png b/assets/sample_input/demo/front/boy_chinese_style.png
new file mode 100644
index 0000000000000000000000000000000000000000..612176a643ba5d3e5c8605d688de792866ea90ea
Binary files /dev/null and b/assets/sample_input/demo/front/boy_chinese_style.png differ
diff --git a/assets/sample_input/demo/front/boy_grey_clothes.png b/assets/sample_input/demo/front/boy_grey_clothes.png
new file mode 100644
index 0000000000000000000000000000000000000000..2fd30066d890613a2cfffa6faf7a8f0447525e6e
Binary files /dev/null and b/assets/sample_input/demo/front/boy_grey_clothes.png differ
diff --git a/assets/sample_input/demo/front/house.png b/assets/sample_input/demo/front/house.png
new file mode 100644
index 0000000000000000000000000000000000000000..73b3ef6eaa74cf08b2414ff7c73c3113e9f9bf69
Binary files /dev/null and b/assets/sample_input/demo/front/house.png differ
diff --git a/assets/sample_input/demo/front/kunkun_law.png b/assets/sample_input/demo/front/kunkun_law.png
new file mode 100644
index 0000000000000000000000000000000000000000..232d6b501d7aec2f4e8191ac2ea7dbbbdf5aa7e4
Binary files /dev/null and b/assets/sample_input/demo/front/kunkun_law.png differ
diff --git a/assets/sample_input/demo/front/lego_astronaut.png b/assets/sample_input/demo/front/lego_astronaut.png
new file mode 100644
index 0000000000000000000000000000000000000000..d97c780ea4d4d7dd9ea69d3fb1bcecd526d9f761
Binary files /dev/null and b/assets/sample_input/demo/front/lego_astronaut.png differ
diff --git a/assets/sample_input/demo/front/lego_green.png b/assets/sample_input/demo/front/lego_green.png
new file mode 100644
index 0000000000000000000000000000000000000000..a324af0d7a515a501866e4b37f24253bf70c1a48
Binary files /dev/null and b/assets/sample_input/demo/front/lego_green.png differ
diff --git a/assets/sample_input/demo/front/lego_red.png b/assets/sample_input/demo/front/lego_red.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b32f7d1241c096b41323c4db0a0f655458ee855
Binary files /dev/null and b/assets/sample_input/demo/front/lego_red.png differ
diff --git a/assets/sample_input/demo/front/lego_wizard.png b/assets/sample_input/demo/front/lego_wizard.png
new file mode 100644
index 0000000000000000000000000000000000000000..396fd41b11200f2e85c6388a251beba6e1441507
Binary files /dev/null and b/assets/sample_input/demo/front/lego_wizard.png differ
diff --git a/assets/sample_input/demo/front/loopy.png b/assets/sample_input/demo/front/loopy.png
new file mode 100644
index 0000000000000000000000000000000000000000..ffdf9844cded0267947e810636c77de6b390b6c4
Binary files /dev/null and b/assets/sample_input/demo/front/loopy.png differ
diff --git a/assets/sample_input/demo/front/mario.png b/assets/sample_input/demo/front/mario.png
new file mode 100644
index 0000000000000000000000000000000000000000..54780d4f3e741e8cad7ca529a64e750a392a159f
Binary files /dev/null and b/assets/sample_input/demo/front/mario.png differ
diff --git a/assets/sample_input/demo/front/marvel_captain.png b/assets/sample_input/demo/front/marvel_captain.png
new file mode 100644
index 0000000000000000000000000000000000000000..753328918caabdda3d61fe044f4df3a6ef447b7f
Binary files /dev/null and b/assets/sample_input/demo/front/marvel_captain.png differ
diff --git a/assets/sample_input/demo/front/marvel_ironman.png b/assets/sample_input/demo/front/marvel_ironman.png
new file mode 100644
index 0000000000000000000000000000000000000000..fb83948dd80b791fdf5032519b5b95c9d467bf6b
Binary files /dev/null and b/assets/sample_input/demo/front/marvel_ironman.png differ
diff --git a/assets/sample_input/demo/front/marvel_spiderman.png b/assets/sample_input/demo/front/marvel_spiderman.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7aa9ce26669ba148bea42f0886d1a6aac4ed316
Binary files /dev/null and b/assets/sample_input/demo/front/marvel_spiderman.png differ
diff --git a/assets/sample_input/demo/front/marvel_superman.png b/assets/sample_input/demo/front/marvel_superman.png
new file mode 100644
index 0000000000000000000000000000000000000000..32bbce4fc8c6f9a86203a74afda7cfe3bfea892c
Binary files /dev/null and b/assets/sample_input/demo/front/marvel_superman.png differ
diff --git a/assets/sample_input/demo/front/penguin.png b/assets/sample_input/demo/front/penguin.png
new file mode 100644
index 0000000000000000000000000000000000000000..2e94acf9b089832e61d6fd6447822aee56eaa424
Binary files /dev/null and b/assets/sample_input/demo/front/penguin.png differ
diff --git a/assets/sample_input/demo/front/sofa.png b/assets/sample_input/demo/front/sofa.png
new file mode 100644
index 0000000000000000000000000000000000000000..0baaa67d61cab614baa43551c98f2b191f33fbe6
Binary files /dev/null and b/assets/sample_input/demo/front/sofa.png differ
diff --git a/configs/accelerate-train-4gpus.yaml b/configs/accelerate-train-4gpus.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..de1bb7c6b6ec93ff568771bc75ca01a3221e43b1
--- /dev/null
+++ b/configs/accelerate-train-4gpus.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+# gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+main_process_port: 34567
diff --git a/configs/accelerate-train.yaml b/configs/accelerate-train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bd37d44db1f3b98c6d06bcd655ece22fc8a38b0
--- /dev/null
+++ b/configs/accelerate-train.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+main_process_port: 35567
diff --git a/configs/all-base-2sides.yaml b/configs/all-base-2sides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..127898437d98c410548a110f83b6e904ea3c2eb0
--- /dev/null
+++ b/configs/all-base-2sides.yaml
@@ -0,0 +1,115 @@
+
+experiment:
+ type: lrm
+ seed: 42
+ parent: gobjaverse-2sides-base
+ child: 0428_conv_e10
+
+model:
+ camera_embed_dim: 1024
+ rendering_samples_per_ray: 96
+ transformer_dim: 768
+ transformer_layers: 12
+ transformer_heads: 12
+ triplane_low_res: 32
+ triplane_high_res: 64 # useless?
+ triplane_dim: 48
+ encoder_type: dinov2
+ encoder_model_name: dinov2_vitb14_reg
+ encoder_feat_dim: 768
+ encoder_freeze: false
+ model_lora_rank: 4
+ conv_fuse: True
+
+dataset:
+ subsets:
+ - name: gobjaverse_delete_tb
+ root_dirs:
+ ['data/data_gobjaverse_delete_tb']
+ meta_path:
+ train: data/data_gobjaverse_delete_tb/train.json
+ val: data/data_gobjaverse_delete_tb/val.json
+ sample_rate: 1.0
+ sample_side_views: 3
+ source_image_res: 336
+ render_image:
+ low: 96
+ high: 288
+ region: 96
+ normalize_camera: true
+ normed_dist_to_center: auto
+ num_train_workers: 4
+ num_val_workers: 2
+ pin_mem: true
+
+train:
+ mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE
+ find_unused_parameters: false
+ loss:
+ pixel_weight: 1.0
+ perceptual_weight: 1.0
+ tv_weight: 5e-4
+ optim:
+ lr: 4e-4 # most important.
+ weight_decay: 0.05
+ beta1: 0.9
+ beta2: 0.95
+ clip_grad_norm: 1.0
+ scheduler:
+ type: cosine
+ warmup_real_iters: 3000
+ batch_size: 8 # REPLACE THIS (PER GPU)
+ accum_steps: 1 # REPLACE THIS
+ epochs: 10 # REPLACE THIS
+ debug_global_steps: null
+
+val:
+ batch_size: 4
+ global_step_period: 1000
+ debug_batches: null
+
+saver:
+ auto_resume: true
+ load_model:
+ type: hugging_face
+ url: zxhezexin/openlrm-mix-base-1.1/model.safetensors
+ checkpoint_root: ./exps/checkpoints
+ checkpoint_global_steps: 1000
+ checkpoint_keep_level: 5
+ load_model_func_kwargs:
+ strict: False
+
+logger:
+ stream_level: WARNING
+ log_level: INFO
+ log_root: ./exps/logs
+ tracker_root: ./exps/trackers
+ enable_profiler: false
+ trackers:
+ - tensorboard
+ image_monitor:
+ train_global_steps: 100
+ samples_per_log: 4
+
+compile:
+ suppress_errors: true
+ print_specializations: true
+ disable: true
+
+inferrer:
+ logger: INFO
+ hugging_face: False
+ iteration: 3330
+ image_format: True
+
+ source_size: 336
+ source_cam_dist: 2.0
+ render_size: 288
+ render_views: 16
+ render_fps: 40
+ frame_size: 2
+ mesh_size: 384
+ mesh_thres: 3.0
+
+convert:
+ global_step:
\ No newline at end of file
diff --git a/configs/all-large-2sides.yaml b/configs/all-large-2sides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c94397086ad243373bf0b64c8c2a8f03675bf0ed
--- /dev/null
+++ b/configs/all-large-2sides.yaml
@@ -0,0 +1,115 @@
+
+experiment:
+ type: lrm
+ seed: 42
+ parent: gobjaverse-2sides-large
+ child: 0428_conv_e10
+
+model:
+ camera_embed_dim: 1024
+ rendering_samples_per_ray: 128
+ transformer_dim: 1024
+ transformer_layers: 16
+ transformer_heads: 16
+ triplane_low_res: 32 # always 32.
+ triplane_high_res: 64 # useless?
+ triplane_dim: 80
+ encoder_type: dinov2
+ encoder_model_name: dinov2_vitb14_reg
+ encoder_feat_dim: 768
+ encoder_freeze: false
+ model_lora_rank: 4
+ conv_fuse: True
+
+dataset:
+ subsets:
+ - name: gobjaverse_delete_tb
+ root_dirs:
+ ['data/data_gobjaverse_delete_tb']
+ meta_path:
+ train: data/data_gobjaverse_delete_tb/train.json
+ val: data/data_gobjaverse_delete_tb/val.json
+ sample_rate: 1.0
+ sample_side_views: 3
+ source_image_res: 336
+ render_image:
+ low: 128
+ high: 384
+ region: 128
+ normalize_camera: true
+ normed_dist_to_center: auto
+ num_train_workers: 4
+ num_val_workers: 2
+ pin_mem: true
+
+train:
+ mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE
+ find_unused_parameters: false
+ loss:
+ pixel_weight: 1.0
+ perceptual_weight: 1.0
+ tv_weight: 5e-4
+ optim:
+ lr: 4e-4 # most important.
+ weight_decay: 0.05
+ beta1: 0.9
+ beta2: 0.95
+ clip_grad_norm: 1.0
+ scheduler:
+ type: cosine
+ warmup_real_iters: 3000
+ batch_size: 2 # REPLACE THIS (PER GPU)
+ accum_steps: 1 # REPLACE THIS
+ epochs: 10 # REPLACE THIS
+ debug_global_steps: null
+
+val:
+ batch_size: 4
+ global_step_period: 1000
+ debug_batches: null
+
+saver:
+ auto_resume: true
+ load_model:
+ type: hugging_face
+ url: zxhezexin/openlrm-mix-large-1.1/model.safetensors
+ checkpoint_root: ./exps/checkpoints
+ checkpoint_global_steps: 1000
+ checkpoint_keep_level: 5
+ load_model_func_kwargs:
+ strict: False
+
+logger:
+ stream_level: WARNING
+ log_level: INFO
+ log_root: ./exps/logs
+ tracker_root: ./exps/trackers
+ enable_profiler: false
+ trackers:
+ - tensorboard
+ image_monitor:
+ train_global_steps: 100
+ samples_per_log: 4
+
+compile:
+ suppress_errors: true
+ print_specializations: true
+ disable: true
+
+inferrer:
+ logger: INFO
+ hugging_face: False
+ iteration: 13340
+ image_format: True
+
+ source_size: 448
+ source_cam_dist: 2.0
+ render_size: 384
+ render_views: 16
+ render_fps: 40
+ frame_size: 2
+ mesh_size: 1024
+ mesh_thres: 1
+
+convert:
+ global_step:
\ No newline at end of file
diff --git a/configs/all-small-2sides.yaml b/configs/all-small-2sides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7a90c858a7213e1a97650ee23dc80b6bea886fff
--- /dev/null
+++ b/configs/all-small-2sides.yaml
@@ -0,0 +1,115 @@
+
+experiment:
+ type: lrm
+ seed: 42
+ parent: gobjaverse-2sides-small
+ child: 0428_conv
+
+model:
+ camera_embed_dim: 1024
+ rendering_samples_per_ray: 96
+ transformer_dim: 512
+ transformer_layers: 12
+ transformer_heads: 8
+ triplane_low_res: 32
+ triplane_high_res: 64
+ triplane_dim: 32
+ encoder_type: dinov2
+ encoder_model_name: dinov2_vits14_reg
+ encoder_feat_dim: 384
+ encoder_freeze: false
+ model_lora_rank: 4
+ conv_fuse: True
+
+dataset:
+ subsets:
+ - name: gobjaverse_delete_tb
+ root_dirs:
+ ['data/data_gobjaverse_delete_tb']
+ meta_path:
+ train: data/data_gobjaverse_delete_tb/train.json
+ val: data/data_gobjaverse_delete_tb/val.json
+ sample_rate: 1.0
+ sample_side_views: 3
+ source_image_res: 224
+ render_image:
+ low: 64
+ high: 192
+ region: 64
+ normalize_camera: true
+ normed_dist_to_center: auto
+ num_train_workers: 4
+ num_val_workers: 2
+ pin_mem: true
+
+train:
+ mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE
+ find_unused_parameters: false
+ loss:
+ pixel_weight: 1.0
+ perceptual_weight: 1.0
+ tv_weight: 5e-4
+ optim:
+ lr: 4e-4
+ weight_decay: 0.05
+ beta1: 0.9
+ beta2: 0.95
+ clip_grad_norm: 1.0
+ scheduler:
+ type: cosine
+ warmup_real_iters: 3000
+ batch_size: 16 # REPLACE THIS (PER GPU)
+ accum_steps: 1 # REPLACE THIS
+ epochs: 10 # REPLACE THIS
+ debug_global_steps: null
+
+val:
+ batch_size: 4
+ global_step_period: 1000
+ debug_batches: null
+
+saver:
+ auto_resume: true
+ load_model:
+ type: hugging_face
+ url: zxhezexin/openlrm-mix-small-1.1/model.safetensors
+ checkpoint_root: ./exps/checkpoints
+ checkpoint_global_steps: 1000
+ checkpoint_keep_level: 5
+ load_model_func_kwargs:
+ strict: False
+
+logger:
+ stream_level: WARNING
+ log_level: INFO
+ log_root: ./exps/logs
+ tracker_root: ./exps/trackers
+ enable_profiler: false
+ trackers:
+ - tensorboard
+ image_monitor:
+ train_global_steps: 100
+ samples_per_log: 4
+
+compile:
+ suppress_errors: true
+ print_specializations: true
+ disable: true
+
+inferrer:
+ logger: INFO
+ hugging_face: False
+ iteration: 1660
+ image_format: True
+
+ source_size: 224
+ source_cam_dist: 2.0
+ render_size: 192
+ render_views: 16
+ render_fps: 40
+ frame_size: 2
+ mesh_size: 384
+ mesh_thres: 3.0
+
+convert:
+ global_step:
\ No newline at end of file
diff --git a/configs/infer-gradio-base.yaml b/configs/infer-gradio-base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d5c0aeed8b3511ea66058f978d977b02fc0b121
--- /dev/null
+++ b/configs/infer-gradio-base.yaml
@@ -0,0 +1,16 @@
+double_sided: True
+
+model:
+ conv_fuse: True
+
+inferrer:
+ hugging_face: True
+ source_size: 336
+ render_size: 288
+ render_views: 100
+ render_fps: 25
+ frame_size: 2
+ mesh_size: 384
+ mesh_thres: 3.0
+ double_sided: True
+ image_format: False
\ No newline at end of file
diff --git a/configs/infer-gradio-large.yaml b/configs/infer-gradio-large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0fadc8e0bd3cf09e4250033bab74b126f1c1e616
--- /dev/null
+++ b/configs/infer-gradio-large.yaml
@@ -0,0 +1,16 @@
+double_sided: True
+
+model:
+ conv_fuse: True
+
+inferrer:
+ hugging_face: True
+ source_size: 448
+ render_size: 384
+ render_views: 100
+ render_fps: 25
+ frame_size: 2
+ mesh_size: 1024
+ mesh_thres: 3.0
+ double_sided: True
+ image_format: False
diff --git a/openlrm/__init__.py b/openlrm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/openlrm/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/openlrm/datasets/__init__.py b/openlrm/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..323127c7d93f0a57f90cc8649ee2a67b6b630762
--- /dev/null
+++ b/openlrm/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .mixer import MixerDataset
diff --git a/openlrm/datasets/back_transform/back_transform.py b/openlrm/datasets/back_transform/back_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..c498177e969e74915cbb4b03fc031054f3ea7efa
--- /dev/null
+++ b/openlrm/datasets/back_transform/back_transform.py
@@ -0,0 +1,24 @@
+import torch
+import torchvision.transforms as transforms
+
+# Add Gaussian noise
+class AddGaussianNoise:
+ def __init__(self, mean=0., std=1.):
+ self.std = std
+ self.mean = mean
+
+ def __call__(self, img):
+ return img + torch.randn(img.size()) * self.std + self.mean
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
+
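+# Augmentations for the back-view source image during training (small shifts,
+# rotations, padding, and Gaussian noise), presumably to make the model robust
+# to imperfectly aligned or edited back views.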
+def transform_back_image():
+ return transforms.Compose([
+ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+ transforms.RandomRotation((-5, 5)),
+ transforms.Pad(padding=10, fill=0, padding_mode='constant'),
+ transforms.ToTensor(),
+ AddGaussianNoise(0., 0.1)
+ ])
diff --git a/openlrm/datasets/base.py b/openlrm/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f327fb147cafbfeaa7455fd8e8f35c9591b2d018
--- /dev/null
+++ b/openlrm/datasets/base.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+import json
+import numpy as np
+import torch
+from PIL import Image
+from megfile import smart_open, smart_path_join, smart_exists
+
+
+class BaseDataset(torch.utils.data.Dataset, ABC):
+ def __init__(self, root_dirs: list[str], meta_path: str):
+ super().__init__()
+ self.root_dirs = root_dirs
+ self.uids = self._load_uids(meta_path)
+
+ def __len__(self):
+ return len(self.uids)
+
+ @abstractmethod
+ def inner_get_item(self, idx):
+ pass
+
+ def __getitem__(self, idx):
+ try:
+ return self.inner_get_item(idx)
+ except Exception as e:
+ print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}")
+ # return self.__getitem__(idx+1)
+ raise e
+
+ @staticmethod
+ def _load_uids(meta_path: str):
+ # meta_path is a json file
+ with open(meta_path, 'r') as f:
+ uids = json.load(f)
+ return uids
+
+ @staticmethod
+ def _load_rgba_image(file_path, bg_color: float = 1.0):
+ ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled '''
+ rgba = np.array(Image.open(smart_open(file_path, 'rb')))
+ rgba = torch.from_numpy(rgba).float() / 255.0
+ rgba = rgba.permute(2, 0, 1).unsqueeze(0)
+        rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :])
+ return rgb
+
+ @staticmethod
+ def _locate_datadir(root_dirs, uid, locator: str):
+ for root_dir in root_dirs:
+ datadir = smart_path_join(root_dir, uid, locator)
+ if smart_exists(datadir):
+ return root_dir
+ raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}")
diff --git a/openlrm/datasets/cam_utils.py b/openlrm/datasets/cam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..70653ae2a7f612714f729c73f45e826109b7e0ff
--- /dev/null
+++ b/openlrm/datasets/cam_utils.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import torch
+
+"""
+R: (N, 3, 3)
+T: (N, 3)
+E: (N, 4, 4)
+vector: (N, 3)
+"""
+
+
+def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from R and T.
+ Batched I/O.
+ """
+ RT = torch.cat((R, T.unsqueeze(-1)), dim=-1)
+ return compose_extrinsic_RT(RT)
+
+
+def compose_extrinsic_RT(RT: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from RT.
+ Batched I/O.
+ """
+ return torch.cat([
+ RT,
+ torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
+ ], dim=1)
+
+
+def decompose_extrinsic_R_T(E: torch.Tensor):
+ """
+ Decompose the standard extrinsic matrix into R and T.
+ Batched I/O.
+ """
+ RT = decompose_extrinsic_RT(E)
+ return RT[:, :, :3], RT[:, :, 3]
+
+
+def decompose_extrinsic_RT(E: torch.Tensor):
+ """
+ Decompose the standard extrinsic matrix into RT.
+ Batched I/O.
+ """
+ return E[:, :3, :]
+
+
+def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
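+    # Apply one rigid transform to all poses so that the first (pivotal) camera
+    # is mapped to a canonical extrinsic at distance dist_to_center from the center.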
+ assert normed_dist_to_center is not None
+ pivotal_pose = compose_extrinsic_RT(poses[:1])
+ dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
+ if normed_dist_to_center == 'auto' else normed_dist_to_center
+
+ # compute camera norm (new version)
+ canonical_camera_extrinsics = torch.tensor([[
+ [1, 0, 0, 0],
+ [0, 0, -1, -dist_to_center],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1],
+ ]], dtype=torch.float32)
+ pivotal_pose_inv = torch.inverse(pivotal_pose)
+ camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)
+
+ # normalize all views
+ poses = compose_extrinsic_RT(poses)
+ poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
+ poses = decompose_extrinsic_RT(poses)
+
+ if ret_transform:
+ return poses, camera_norm_matrix.squeeze(dim=0)
+ return poses
+
+
+def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
+ """
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ Return batched fx, fy, cx, cy
+ """
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+ cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+ width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+ fx, fy = fx / width, fy / height
+ cx, cy = cx / width, cy / height
+ return fx, fy, cx, cy
+
+
+def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
+ """
+ RT: (N, 3, 4)
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ """
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+ return torch.cat([
+ RT.reshape(-1, 12),
+ fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
+ ], dim=-1)
+
+
+def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
+ """
+ RT: (N, 3, 4)
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ """
+ E = compose_extrinsic_RT(RT)
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+ I = torch.stack([
+ torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+ torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
+ ], dim=1)
+ return torch.cat([
+ E.reshape(-1, 16),
+ I.reshape(-1, 9),
+ ], dim=-1)
+
+
+def center_looking_at_camera_pose(
+ camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None,
+ device: torch.device = torch.device('cpu'),
+ ):
+ """
+ camera_position: (M, 3)
+ look_at: (3)
+ up_world: (3)
+ return: (M, 3, 4)
+ """
+ # by default, looking at the origin and world up is pos-z
+ if look_at is None:
+ look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
+ if up_world is None:
+ up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
+ look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+ up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+ z_axis = camera_position - look_at
+ z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
+    x_axis = torch.cross(up_world, z_axis, dim=-1)
+    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
+    y_axis = torch.cross(z_axis, x_axis, dim=-1)
+ y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
+ extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+ return extrinsics
+
+
+def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')):
+ """
+ n_views: number of surrounding views
+ radius: camera dist to center
+ height: height of the camera
+ return: (M, 3, 4)
+ """
+ assert n_views > 0
+ assert radius > 0
+
+ theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device)
+ projected_radius = math.sqrt(radius ** 2 - height ** 2)
+ x = torch.cos(theta) * projected_radius
+ y = torch.sin(theta) * projected_radius
+ z = torch.full((n_views,), height, device=device)
+
+ camera_positions = torch.stack([x, y, z], dim=1)
+ extrinsics = center_looking_at_camera_pose(camera_positions, device=device)
+
+ return extrinsics
+
+
+def create_intrinsics(
+ f: float,
+ c: float = None, cx: float = None, cy: float = None,
+ w: float = 1., h: float = 1.,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device('cpu'),
+ ):
+ """
+ return: (3, 2)
+ """
+ fx = fy = f
+ if c is not None:
+ assert cx is None and cy is None, "c and cx/cy cannot be used together"
+ cx = cy = c
+ else:
+ assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided"
+ fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1.
+ intrinsics = torch.tensor([
+ [fx, fy],
+ [cx, cy],
+ [w, h],
+ ], dtype=dtype, device=device)
+ return intrinsics
diff --git a/openlrm/datasets/gobjaverse.py b/openlrm/datasets/gobjaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e1997a6e658c06324bdbb6e19e218af30d4014
--- /dev/null
+++ b/openlrm/datasets/gobjaverse.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from typing import Union
+import random
+import numpy as np
+import torch
+from megfile import smart_path_join, smart_open
+
+from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse
+from ..utils.proxy import no_proxy
+from .objaverse import ObjaverseDataset
+from .back_transform.back_transform import transform_back_image
+
+from PIL import Image
+from torchvision import transforms
+
+__all__ = ['GobjaverseDataset']
+
+def opposite_view(i):
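+    # Return the index of the view roughly opposite in azimuth (180 degrees away).
+    # Assumes the GObjaverse-style layout used here: indices 0-24 lie on a 24-step
+    # orbit and indices 27-39 lie on a separate 12-step orbit.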
+ if 0 <= i <= 24:
+ return (i + 12) % 24
+ elif 27 <= i <= 39:
+ return ((i - 27) + 6) % 12 + 27
+ else:
+ raise ValueError("Input number must be between 0-24 or 27-39.")
+
+def get_random_views(rgba_dir, num_views=4):
+ all_files = [f for f in os.listdir(rgba_dir) if f.endswith('.png')]
+ view_numbers = [int(os.path.splitext(f)[0]) for f in all_files]
+ selected_views = random.sample(view_numbers, num_views)
+ return np.array(selected_views)
+
+class GobjaverseDataset(ObjaverseDataset):
+
+ def __init__(self, root_dirs: list[str], meta_path: str,
+ sample_side_views: int,
+ render_image_res_low: int, render_image_res_high: int, render_region_size: int,
+ source_image_res: int, normalize_camera: bool,
+ normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32):
+ super().__init__(
+ root_dirs, meta_path,
+ sample_side_views,
+ render_image_res_low,
+ render_image_res_high,
+ render_region_size,
+ source_image_res,
+ normalize_camera,
+ normed_dist_to_center,
+ num_all_views,
+ )
+
+ self.back_transforms = transform_back_image()
+
+ # This is for gobjaverse and objaverse_mengchen
+ @staticmethod
+ def _load_pose_txt(file_path): # load .txt #!!!
+ with open(file_path, 'r') as file:
+ lines = file.readlines()
+ pose_data = np.array([list(map(float, line.split())) for line in lines], dtype=np.float32)
+        pose = pose_data.reshape(4, 4)  # flat [16] -> [4, 4]
+        opengl2opencv = np.array([
+            [1, 0, 0, 0],
+            [0, -1, 0, 0],
+            [0, 0, -1, 0],
+            [0, 0, 0, 1]
+        ], dtype=np.float32)
+        # Convert the camera pose from OpenGL to OpenCV convention and return a tensor.
+        pose = torch.from_numpy(np.matmul(pose, opengl2opencv))
+        return pose[:3, :]  # [4, 4] -> [3, 4]
+
+ @staticmethod
+ def _load_rgba_image_transform(file_path, bg_color: float = 1.0, extra_transforms=None): #!!!
+ ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled '''
+ rgba = np.array(Image.open(smart_open(file_path, 'rb')) ) # (512, 512, 4)
+ rgba = torch.from_numpy(rgba).float() / 255.0
+ rgba = rgba.permute(2, 0, 1).unsqueeze(0)
+ rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :])
+ if extra_transforms is not None:
+ rgb = extra_transforms(
+ transforms.ToPILImage()(rgb.squeeze())
+ ).unsqueeze(0)
+ return rgb # [1, 3, 512, 512]
+
+ @no_proxy
+ def inner_get_item(self, idx):
+ """
+ Loaded contents:
+ rgbs: [M, 3, H, W]
+ poses: [M, 3, 4], [R|t]
+            intrinsics: [3, 2], [[fx, fy], [cx, cy], [width, height]]
+ """
+ uid = self.uids[idx]
+ root_dir = self._locate_datadir(self.root_dirs, uid, locator="pose")
+
+ pose_dir = os.path.join(root_dir, uid, 'pose')
+ rgba_dir = os.path.join(root_dir, uid, 'rgb')
+
+ # only one intrinsics
+ intrinsics = torch.tensor([[384, 384], [256, 256], [512, 512]], dtype=torch.float)
+
+ # sample views (incl. source view and side views)
+ sample_views = get_random_views(rgba_dir, num_views=self.sample_side_views)
+ source_image_view_back = opposite_view(sample_views[0])
+ sample_views = np.insert(sample_views, 1, source_image_view_back)
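+        # sample_views[0] is the front source view; the opposite view inserted at
+        # index 1 is loaded again below (with augmentation) as source_image_back.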
+
+ poses, rgbs, bg_colors = [], [], []
+ source_image = None
+ for view in sample_views:
+ pose_path = smart_path_join(pose_dir, f'{view:03d}.txt')
+ rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png')
+ pose = self._load_pose_txt(pose_path) #!!!
+ bg_color = random.choice([0.0, 0.5, 1.0])
+ rgb = self._load_rgba_image(rgba_path, bg_color=bg_color)
+ poses.append(pose)
+ rgbs.append(rgb)
+ bg_colors.append(bg_color)
+ if source_image is None:
+ source_image = self._load_rgba_image(rgba_path, bg_color=1.0)
+ assert source_image is not None, "Really bad luck!"
+ poses = torch.stack(poses, dim=0)
+ rgbs = torch.cat(rgbs, dim=0)
+
+ #!!! lora for the backview
+ source_image_back = self._load_rgba_image_transform(smart_path_join(rgba_dir, f'{sample_views[1]:03d}.png'), bg_color=bg_color)
+
+ if self.normalize_camera:
+ poses = camera_normalization_objaverse(self.normed_dist_to_center, poses)
+
+ # build source and target camera features
+ source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0)
+ render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1))
+
+ # adjust source image resolution
+ source_image = torch.nn.functional.interpolate(
+ source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
+ source_image = torch.clamp(source_image, 0, 1)
+
+ #!!! adjust source_image_back resolution
+ source_image_back = torch.nn.functional.interpolate(
+ source_image_back, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
+ source_image_back = torch.clamp(source_image_back, 0, 1)
+
+ # adjust render image resolution and sample intended rendering region
+ render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1)
+ render_image = torch.nn.functional.interpolate(
+ rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True)
+ render_image = torch.clamp(render_image, 0, 1)
+ anchors = torch.randint(
+ 0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2))
+ crop_indices = torch.arange(0, self.render_region_size, device=render_image.device)
+ index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1)
+ index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size)
+ batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
+ cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2)
+
+ return {
+ 'uid': uid,
+ 'source_camera': source_camera,
+ 'render_camera': render_camera,
+ 'source_image': source_image,
+ 'render_image': cropped_render_image,
+ 'source_image_back': source_image_back, #!!!
+ 'render_anchors': anchors,
+ 'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1),
+ 'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1),
+ }
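Side note on the cropping logic above (the same pattern reappears in `objaverse.py` below): the three index tensors broadcast against each other, so a single advanced-indexing expression cuts one random `render_region_size` crop per view. A minimal, self-contained sketch with illustrative shapes, not the training defaults:

```python
import torch

M, C, res, crop = 4, 3, 128, 64                        # views, channels, render res, crop size
render_image = torch.rand(M, C, res, res)

anchors = torch.randint(0, res - crop + 1, size=(M, 2))                   # top-left corner per view
crop_indices = torch.arange(0, crop)
index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, crop, 1)   # row indices  [M, crop, 1]
index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, crop)   # col indices  [M, 1, crop]
batch_indices = torch.arange(M).view(-1, 1, 1)                            # view index   [M, 1, 1]

# The advanced indices broadcast to [M, crop, crop]; the untouched channel axis moves to
# the end, giving [M, crop, crop, C], which permute() restores to [M, C, crop, crop].
cropped = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2)
assert cropped.shape == (M, C, crop, crop)
```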
diff --git a/openlrm/datasets/mixer.py b/openlrm/datasets/mixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8455abf54be7884928fb102dbd610aa4f927b634
--- /dev/null
+++ b/openlrm/datasets/mixer.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from functools import partial
+import torch
+
+__all__ = ['MixerDataset']
+
+
+class MixerDataset(torch.utils.data.Dataset):
+
+ def __init__(self,
+ split: str,
+ subsets: list[dict],
+ **dataset_kwargs,
+ ):
+ self.subsets = [
+ self._dataset_fn(subset, split)(**dataset_kwargs)
+ for subset in subsets
+ ]
+ self.virtual_lens = [
+ math.ceil(subset_config['sample_rate'] * len(subset_obj))
+ for subset_config, subset_obj in zip(subsets, self.subsets)
+ ]
+
+ @staticmethod
+ def _dataset_fn(subset_config: dict, split: str):
+ name = subset_config['name']
+
+ dataset_cls = None
+ if name == "objaverse":
+ from .objaverse import ObjaverseDataset
+ dataset_cls = ObjaverseDataset
+ elif name == 'gobjaverse_delete_tb':
+ from .gobjaverse import GobjaverseDataset
+ dataset_cls = GobjaverseDataset
+ # elif name == 'mvimgnet':
+ # from .mvimgnet import MVImgNetDataset
+ # dataset_cls = MVImgNetDataset
+ else:
+ raise NotImplementedError(f"Dataset {name} not implemented")
+
+ return partial(
+ dataset_cls,
+ root_dirs=subset_config['root_dirs'],
+ meta_path=subset_config['meta_path'][split],
+ )
+
+ def __len__(self):
+ return sum(self.virtual_lens)
+
+ def __getitem__(self, idx):
+ subset_idx = 0
+ virtual_idx = idx
+ while virtual_idx >= self.virtual_lens[subset_idx]:
+ virtual_idx -= self.virtual_lens[subset_idx]
+ subset_idx += 1
+ real_idx = virtual_idx % len(self.subsets[subset_idx])
+ return self.subsets[subset_idx][real_idx]
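For reference, the virtual-length bookkeeping above oversamples a subset when its `sample_rate` exceeds 1: the while-loop picks the subset, and the modulo wraps the virtual index back into the subset's real range. A small standalone sketch with made-up lengths:

```python
import math

def locate(idx, virtual_lens, real_lens):
    """Mirror of MixerDataset.__getitem__ index arithmetic (illustrative only)."""
    subset_idx, virtual_idx = 0, idx
    while virtual_idx >= virtual_lens[subset_idx]:
        virtual_idx -= virtual_lens[subset_idx]
        subset_idx += 1
    return subset_idx, virtual_idx % real_lens[subset_idx]

# e.g. two subsets of 80 and 50 samples with sample_rate 1.5 and 1.0
real_lens = [80, 50]
virtual_lens = [math.ceil(1.5 * 80), math.ceil(1.0 * 50)]   # [120, 50]
assert locate(0, virtual_lens, real_lens) == (0, 0)
assert locate(100, virtual_lens, real_lens) == (0, 20)      # virtual index 100 wraps to real sample 20
assert locate(130, virtual_lens, real_lens) == (1, 10)
```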
diff --git a/openlrm/datasets/objaverse.py b/openlrm/datasets/objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..728a5203b3f6b1046aa7ac126c5e86e26fd90e07
--- /dev/null
+++ b/openlrm/datasets/objaverse.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from typing import Union
+import random
+import numpy as np
+import torch
+from megfile import smart_path_join, smart_open
+
+from .base import BaseDataset
+from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse
+from ..utils.proxy import no_proxy
+
+__all__ = ['ObjaverseDataset']
+
+
+class ObjaverseDataset(BaseDataset):
+
+ def __init__(self, root_dirs: list[str], meta_path: str,
+ sample_side_views: int,
+ render_image_res_low: int, render_image_res_high: int, render_region_size: int,
+ source_image_res: int, normalize_camera: bool,
+ normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32):
+ super().__init__(root_dirs, meta_path)
+ self.sample_side_views = sample_side_views # 3
+ self.render_image_res_low = render_image_res_low # 64
+ self.render_image_res_high = render_image_res_high # 192
+ self.render_region_size = render_region_size # 64
+        self.source_image_res = source_image_res  # 224
+ self.normalize_camera = normalize_camera # True
+ self.normed_dist_to_center = normed_dist_to_center # 'auto'
+ self.num_all_views = num_all_views
+
+ @staticmethod
+ def _load_pose(file_path):
+ pose = np.load(smart_open(file_path, 'rb'))
+ pose = torch.from_numpy(pose).float()
+ return pose
+
+ @no_proxy
+ def inner_get_item(self, idx):
+ """
+ Loaded contents:
+ rgbs: [M, 3, H, W]
+ poses: [M, 3, 4], [R|t]
+            intrinsics: [3, 2], [[fx, fy], [cx, cy], [width, height]]
+ """
+ uid = self.uids[idx]
+ root_dir = self._locate_datadir(self.root_dirs, uid, locator="intrinsics.npy")
+
+ pose_dir = os.path.join(root_dir, uid, 'pose')
+ rgba_dir = os.path.join(root_dir, uid, 'rgba')
+ intrinsics_path = os.path.join(root_dir, uid, 'intrinsics.npy')
+
+ # load intrinsics
+ intrinsics = np.load(smart_open(intrinsics_path, 'rb'))
+ intrinsics = torch.from_numpy(intrinsics).float()
+
+ # sample views (incl. source view and side views)
+ sample_views = np.random.choice(range(self.num_all_views), self.sample_side_views + 1, replace=False)
+ poses, rgbs, bg_colors = [], [], []
+ source_image = None
+ for view in sample_views:
+ pose_path = smart_path_join(pose_dir, f'{view:03d}.npy')
+ rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png')
+ pose = self._load_pose(pose_path)
+ bg_color = random.choice([0.0, 0.5, 1.0])
+ rgb = self._load_rgba_image(rgba_path, bg_color=bg_color)
+ poses.append(pose)
+ rgbs.append(rgb)
+ bg_colors.append(bg_color)
+ if source_image is None:
+ source_image = self._load_rgba_image(rgba_path, bg_color=1.0)
+ assert source_image is not None, "Really bad luck!"
+ poses = torch.stack(poses, dim=0)
+ rgbs = torch.cat(rgbs, dim=0)
+
+ if self.normalize_camera:
+ poses = camera_normalization_objaverse(self.normed_dist_to_center, poses)
+
+ # build source and target camera features
+ source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0)
+ render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1))
+
+ # adjust source image resolution
+ source_image = torch.nn.functional.interpolate(
+ source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
+ source_image = torch.clamp(source_image, 0, 1)
+
+ # adjust render image resolution and sample intended rendering region
+ render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1)
+ render_image = torch.nn.functional.interpolate(
+ rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True)
+ render_image = torch.clamp(render_image, 0, 1)
+ anchors = torch.randint(
+ 0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2))
+ crop_indices = torch.arange(0, self.render_region_size, device=render_image.device)
+ index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1)
+ index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size)
+ batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
+ cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2)
+
+ return {
+ 'uid': uid,
+ 'source_camera': source_camera,
+ 'render_camera': render_camera,
+ 'source_image': source_image,
+ 'render_image': cropped_render_image,
+ 'render_anchors': anchors,
+ 'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1),
+ 'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1),
+ }
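As a reminder of what the `_load_rgba_image*` helpers produce before any resizing, the alpha blend is simply `rgb = foreground * alpha + bg_color * (1 - alpha)` on 0-1 scaled tensors. A tiny standalone check:

```python
import torch

rgba = torch.rand(1, 4, 8, 8)     # stand-in for a decoded RGBA image, [1, 4, H, W] in [0, 1]
bg_color = 0.5                    # the loaders draw this from {0.0, 0.5, 1.0}
rgb = rgba[:, :3] * rgba[:, 3:4] + bg_color * (1 - rgba[:, 3:4])
assert rgb.shape == (1, 3, 8, 8) and 0.0 <= rgb.min() and rgb.max() <= 1.0
```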
diff --git a/openlrm/launch.py b/openlrm/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb425c93344597a291061a36b8eb94a9c0858ef
--- /dev/null
+++ b/openlrm/launch.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+
+from openlrm.runners import REGISTRY_RUNNERS
+
+
+def main():
+
+ parser = argparse.ArgumentParser(description='OpenLRM launcher')
+ parser.add_argument('runner', type=str, help='Runner to launch')
+ args, unknown = parser.parse_known_args()
+
+ if args.runner not in REGISTRY_RUNNERS:
+ raise ValueError('Runner {} not found'.format(args.runner))
+
+ RunnerClass = REGISTRY_RUNNERS[args.runner]
+ with RunnerClass() as runner:
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
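The launcher is a thin registry dispatch: the first positional argument selects a runner class, which is then used as a context manager. A minimal sketch of the same pattern (the registration decorator and the runner name `train.demo` are hypothetical, not this repo's API; the real names live in `openlrm.runners.REGISTRY_RUNNERS`):

```python
REGISTRY = {}

def register(name):
    def wrapper(cls):
        REGISTRY[name] = cls
        return cls
    return wrapper

@register('train.demo')          # hypothetical runner name
class DemoRunner:
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False             # do not swallow exceptions
    def run(self):
        print('running')

with REGISTRY['train.demo']() as runner:   # mirrors `with RunnerClass() as runner`
    runner.run()
```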
diff --git a/openlrm/losses/__init__.py b/openlrm/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8da8292b9982cddbfaf84ad3ea74b4bfa9925d
--- /dev/null
+++ b/openlrm/losses/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .pixelwise import *
+from .perceptual import *
+from .tvloss import *
diff --git a/openlrm/losses/perceptual.py b/openlrm/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eead0d1a207e1863598d3400a4a42bd40549114
--- /dev/null
+++ b/openlrm/losses/perceptual.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+__all__ = ['LPIPSLoss']
+
+
+class LPIPSLoss(nn.Module):
+ """
+ Compute LPIPS loss between two images.
+ """
+
+ def __init__(self, device, prefech: bool = False):
+ super().__init__()
+ self.device = device
+ self.cached_models = {}
+ if prefech:
+ self.prefetch_models()
+
+ def _get_model(self, model_name: str):
+ if model_name not in self.cached_models:
+ import warnings
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=UserWarning)
+ import lpips
+ _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
+ _model = torch.compile(_model)
+ self.cached_models[model_name] = _model
+ return self.cached_models[model_name]
+
+ def prefetch_models(self):
+ _model_names = ['alex', 'vgg']
+ for model_name in _model_names:
+ self._get_model(model_name)
+
+ def forward(self, x, y, is_training: bool = True):
+ """
+ Assume images are 0-1 scaled and channel first.
+
+ Args:
+ x: [N, M, C, H, W]
+ y: [N, M, C, H, W]
+            is_training: use the VGG backbone when True, AlexNet otherwise.
+
+ Returns:
+ Mean-reduced LPIPS loss across batch.
+ """
+ model_name = 'vgg' if is_training else 'alex'
+ loss_fn = self._get_model(model_name)
+ N, M, C, H, W = x.shape
+ x = x.reshape(N*M, C, H, W)
+ y = y.reshape(N*M, C, H, W)
+ image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
+ batch_loss = image_loss.reshape(N, M).mean(dim=1)
+ all_loss = batch_loss.mean()
+ return all_loss
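A hedged usage sketch for LPIPSLoss; shapes follow the forward() docstring, and the `lpips` package this wrapper builds on will download its backbone weights on first use:

```python
import torch
from openlrm.losses import LPIPSLoss

lpips_fn = LPIPSLoss(device='cpu')                 # prefetching the backbones is optional
x = torch.rand(2, 4, 3, 64, 64)                    # [N, M, C, H, W], 0-1 scaled
y = torch.rand(2, 4, 3, 64, 64)
loss = lpips_fn(x, y, is_training=False)           # AlexNet backbone at eval time
print(float(loss))                                 # scalar: mean over views, then batch
```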
diff --git a/openlrm/losses/pixelwise.py b/openlrm/losses/pixelwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..f936d9960041e49baf2ab1334e9639c219212ec2
--- /dev/null
+++ b/openlrm/losses/pixelwise.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+__all__ = ['PixelLoss']
+
+
+class PixelLoss(nn.Module):
+ """
+ Pixel-wise loss between two images.
+ """
+
+ def __init__(self, option: str = 'mse'):
+ super().__init__()
+ self.loss_fn = self._build_from_option(option)
+
+ @staticmethod
+ def _build_from_option(option: str, reduction: str = 'none'):
+ if option == 'mse':
+ return nn.MSELoss(reduction=reduction)
+ elif option == 'l1':
+ return nn.L1Loss(reduction=reduction)
+ else:
+ raise NotImplementedError(f'Unknown pixel loss option: {option}')
+
+ @torch.compile
+ def forward(self, x, y):
+ """
+ Assume images are channel first.
+
+ Args:
+ x: [N, M, C, H, W]
+ y: [N, M, C, H, W]
+
+ Returns:
+ Mean-reduced pixel loss across batch.
+ """
+ N, M, C, H, W = x.shape
+ x = x.reshape(N*M, C, H, W)
+ y = y.reshape(N*M, C, H, W)
+ image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3])
+ batch_loss = image_loss.reshape(N, M).mean(dim=1)
+ all_loss = batch_loss.mean()
+ return all_loss
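Because every image in the batch has the same number of elements, the per-image then per-sample reduction in PixelLoss collapses to the plain mean over all elements; a quick sanity check:

```python
import torch
from openlrm.losses import PixelLoss

x = torch.rand(2, 4, 3, 32, 32)    # [N, M, C, H, W]
y = torch.rand(2, 4, 3, 32, 32)
loss = PixelLoss(option='mse')(x, y)
assert torch.allclose(loss, torch.nn.functional.mse_loss(x, y), atol=1e-6)
```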
diff --git a/openlrm/losses/tvloss.py b/openlrm/losses/tvloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a13b69b6f9fcacc38940373bf8159b3cf61459
--- /dev/null
+++ b/openlrm/losses/tvloss.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+__all__ = ['TVLoss']
+
+
+class TVLoss(nn.Module):
+ """
+    Total variation (TV) loss.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def numel_excluding_first_dim(self, x):
+ return x.numel() // x.shape[0]
+
+ @torch.compile
+ def forward(self, x):
+ """
+        Assume batched, channel-first images with an inner view dimension.
+
+ Args:
+ x: [N, M, C, H, W]
+
+ Returns:
+ Mean-reduced TV loss with element-level scaling.
+ """
+ N, M, C, H, W = x.shape
+ x = x.reshape(N*M, C, H, W)
+ diff_i = x[..., 1:, :] - x[..., :-1, :]
+ diff_j = x[..., :, 1:] - x[..., :, :-1]
+ div_i = self.numel_excluding_first_dim(diff_i)
+ div_j = self.numel_excluding_first_dim(diff_j)
+ tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i
+ tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j
+ tv = tv_i + tv_j
+ batch_tv = tv.reshape(N, M).mean(dim=1)
+ all_tv = batch_tv.mean()
+ return all_tv
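A tiny worked example of the TV term above: for a single 2x2 image [[0, 1], [2, 3]], the vertical differences are [2, 2] and the horizontal ones [1, 1], so the loss is mean(2^2, 2^2) + mean(1^2, 1^2) = 4 + 1 = 5.

```python
import torch
from openlrm.losses import TVLoss

x = torch.tensor([[0., 1.],
                  [2., 3.]]).view(1, 1, 1, 2, 2)   # [N, M, C, H, W]
print(TVLoss()(x))                                 # tensor(5.)
```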
diff --git a/openlrm/models/__init__.py b/openlrm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abde0f5ffc44d03e5bba48e09eb442e73eeb2adc
--- /dev/null
+++ b/openlrm/models/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .modeling_lrm import ModelLRM
+
+
+model_dict = {
+ 'lrm': ModelLRM,
+}
diff --git a/openlrm/models/block.py b/openlrm/models/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5eb00b873fbf05928c078c9d379c8728628150d
--- /dev/null
+++ b/openlrm/models/block.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch.nn as nn
+import loratorch as lora
+from .modulate import ModLN
+
+
+class BasicBlock(nn.Module):
+ """
+ Transformer block that is in its simplest form.
+ Designed for PF-LRM architecture.
+ """
+ # Block contains a self-attention layer and an MLP
+ def __init__(self, inner_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
+ self.self_attn = nn.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x):
+ # x: [N, L, D]
+ before_sa = self.norm1(x)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class ConditionBlock(nn.Module):
+ """
+ Transformer block that takes in a cross-attention condition.
+ Designed for SparseLRM architecture.
+ """
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.,
+ lora_rank: int = 0):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
+ self.cross_attn = lora.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+ dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank)
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
+ self.self_attn = lora.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank)
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x, cond):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond]
+ x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
+ before_sa = self.norm2(x)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm3(x))
+ return x
+
+
+class ConditionModulationBlock(nn.Module):
+ """
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
+ Designed for raw LRM architecture.
+ """
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
+ def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
+ attn_drop: float = 0., attn_bias: bool = False,
+ mlp_ratio: float = 4., mlp_drop: float = 0.,
+ lora_rank: int = 0):
+ super().__init__()
+ self.norm1 = ModLN(inner_dim, mod_dim, eps)
+ self.cross_attn = lora.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+ dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank)
+ self.norm2 = ModLN(inner_dim, mod_dim, eps)
+ self.self_attn = lora.MultiheadAttention(
+ embed_dim=inner_dim, num_heads=num_heads,
+ dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank)
+ self.norm3 = ModLN(inner_dim, mod_dim, eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+ nn.GELU(),
+ nn.Dropout(mlp_drop),
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+ nn.Dropout(mlp_drop),
+ )
+
+ def forward(self, x, cond, mod):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond]
+ # mod: [N, D_mod]
+ x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
+ before_sa = self.norm2(x, mod)
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
+ x = x + self.mlp(self.norm3(x, mod))
+ return x
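A hedged shape sketch for ConditionModulationBlock, using the same loratorch wrapper and `ModLN` signature as the definitions above; the dimensions below are hypothetical, not the training config:

```python
import torch
from openlrm.models.block import ConditionModulationBlock

block = ConditionModulationBlock(
    inner_dim=512, cond_dim=768, mod_dim=256, num_heads=8, eps=1e-6, lora_rank=4)

x = torch.rand(2, 1024, 512)      # [N, L, D] latent tokens
cond = torch.rand(2, 257, 768)    # [N, L_cond, D_cond] image features
mod = torch.rand(2, 256)          # [N, D_mod] camera modulation vector
out = block(x, cond, mod)
assert out.shape == x.shape       # residual updates keep the token shape
```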
diff --git a/openlrm/models/embedder.py b/openlrm/models/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..379721cf5c146cb29aca7695cff8558b0d23673c
--- /dev/null
+++ b/openlrm/models/embedder.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+
+class CameraEmbedder(nn.Module):
+ """
+ Embed camera features to a high-dimensional vector.
+
+ Reference:
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
+ """
+ def __init__(self, raw_dim: int, embed_dim: int):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(raw_dim, embed_dim),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim),
+ )
+
+ @torch.compile
+ def forward(self, x):
+ return self.mlp(x)
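Hedged usage sketch for CameraEmbedder; `raw_dim=16` assumes the source camera feature is the flattened 3x4 [R|t] (12 values) plus [fx, fy, cx, cy] produced by `build_camera_principle`, and `embed_dim` is illustrative:

```python
import torch
from openlrm.models.embedder import CameraEmbedder

embedder = CameraEmbedder(raw_dim=16, embed_dim=1024)
camera = torch.rand(2, 16)            # [N, raw_dim]
print(embedder(camera).shape)         # torch.Size([2, 1024])
```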
diff --git a/openlrm/models/encoders/__init__.py b/openlrm/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/openlrm/models/encoders/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/openlrm/models/encoders/dino_wrapper.py b/openlrm/models/encoders/dino_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f475781af22630f3989ff912283fdc2e1b13b4
--- /dev/null
+++ b/openlrm/models/encoders/dino_wrapper.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from transformers import ViTImageProcessor, ViTModel
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class DinoWrapper(nn.Module):
+ """
+ Dino v1 wrapper using huggingface transformer implementation.
+ """
+ def __init__(self, model_name: str, freeze: bool = True):
+ super().__init__()
+ self.model, self.processor = self._build_dino(model_name)
+ if freeze:
+ self._freeze()
+
+ @torch.compile
+ def forward_model(self, inputs):
+ return self.model(**inputs, interpolate_pos_encoding=True)
+
+ def forward(self, image):
+ # image: [N, C, H, W], on cpu
+ # RGB image with [0,1] scale and properly sized
+ inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device)
+ # This resampling of positional embedding uses bicubic interpolation
+ outputs = self.forward_model(inputs)
+ last_hidden_states = outputs.last_hidden_state
+ return last_hidden_states
+
+ def _freeze(self):
+ logger.warning(f"======== Freezing DinoWrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
+ import requests
+ try:
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+ processor = ViTImageProcessor.from_pretrained(model_name)
+ return model, processor
+ except requests.exceptions.ProxyError as err:
+ if proxy_error_retries > 0:
+ print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...")
+ import time
+ time.sleep(proxy_error_cooldown)
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
+ else:
+ raise err
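Hedged usage sketch for DinoWrapper. The checkpoint name is an assumption (a public DINO ViT-B/16 on the Hugging Face Hub), not something fixed by this file; weights are downloaded on first use:

```python
import torch
from openlrm.models.encoders.dino_wrapper import DinoWrapper

encoder = DinoWrapper(model_name='facebook/dino-vitb16', freeze=True)
image = torch.rand(1, 3, 224, 224)    # RGB in [0, 1], already resized by the caller
tokens = encoder(image)               # [1, 1 + num_patches, hidden_dim]
print(tokens.shape)                   # torch.Size([1, 197, 768]) for ViT-B/16 at 224 px
```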
diff --git a/openlrm/models/encoders/dinov2/__init__.py b/openlrm/models/encoders/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/openlrm/models/encoders/dinov2/hub/__init__.py b/openlrm/models/encoders/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/openlrm/models/encoders/dinov2/hub/backbones.py b/openlrm/models/encoders/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fd8c4010204da1f1e413db66d24a87e2a39a358
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/backbones.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ # ********** Modified by Zexin He in 2023-2024 **********
+ state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern
+ if vit_kwargs.get("modulation_dim") is not None:
+ state_dict = {
+ k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v
+ for k, v in state_dict.items()
+ }
+ model.load_state_dict(state_dict, strict=False)
+ else:
+ model.load_state_dict(state_dict, strict=True)
+ # ********************************************************
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
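The modified loading branch above only triggers when a `modulation_dim` kwarg reaches the vendored vision transformer (implied by the `vit_kwargs.get("modulation_dim")` check); in that case the pretrained `norm1`/`norm2` weights are remapped onto the `norm1.norm`/`norm2.norm` sub-layers and loading becomes non-strict. A hedged sketch, assuming the local vision_transformer accepts such a kwarg and with the dimension value purely illustrative:

```python
from openlrm.models.encoders.dinov2.hub.backbones import dinov2_vitb14

# plain backbone: strict loading of the LVD-142M weights
plain = dinov2_vitb14(pretrained=True)

# modulated backbone (modulation_dim is an assumed kwarg of the vendored vision_transformer):
# pretrained LayerNorm weights land in 'norm1.norm' / 'norm2.norm', loaded with strict=False
modulated = dinov2_vitb14(pretrained=True, modulation_dim=1024)
```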
diff --git a/openlrm/models/encoders/dinov2/hub/classifiers.py b/openlrm/models/encoders/dinov2/hub/classifiers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/classifiers.py
@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+ *,
+ arch_name: str = "vit_large",
+ patch_size: int = 14,
+ embed_dim: int = 1024,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+ if pretrained:
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ layers_str = str(layers) if layers == 4 else ""
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ linear_head.load_state_dict(state_dict, strict=True)
+
+ return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+ def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+ super().__init__()
+ self.backbone = backbone
+ self.linear_head = linear_head
+ self.layers = layers
+
+ def forward(self, x):
+ if self.layers == 1:
+ x = self.backbone.forward_features(x)
+ cls_token = x["x_norm_clstoken"]
+ patch_tokens = x["x_norm_patchtokens"]
+ # fmt: off
+ linear_input = torch.cat([
+ cls_token,
+ patch_tokens.mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ elif self.layers == 4:
+ x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+ # fmt: off
+ linear_input = torch.cat([
+ x[0][1],
+ x[1][1],
+ x[2][1],
+ x[3][1],
+ x[3][0].mean(dim=1),
+ ], dim=1)
+ # fmt: on
+ else:
+ assert False, f"Unsupported number of layers: {self.layers}"
+ return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+ *,
+ arch_name: str = "vit_large",
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ **kwargs,
+):
+ backbone = _make_dinov2_model(
+ arch_name=arch_name,
+ pretrained=pretrained,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ **kwargs,
+ )
+
+ embed_dim = backbone.embed_dim
+ patch_size = backbone.patch_size
+ linear_head = _make_dinov2_linear_classification_head(
+ arch_name=arch_name,
+ patch_size=patch_size,
+ embed_dim=embed_dim,
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=num_register_tokens,
+ )
+
+ return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_lc(
+ *,
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.IMAGENET1K,
+ **kwargs,
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_small",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_base",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_large",
+ layers=layers,
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg_lc(
+ *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+ """
+ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+ """
+ return _make_dinov2_linear_classifier(
+ arch_name="vit_giant2",
+ layers=layers,
+ ffn_layer="swiglufused",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
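For completeness, a short hedged usage sketch of the vendored linear-classifier helpers (downloads both backbone and head weights):

```python
import torch
from openlrm.models.encoders.dinov2.hub.classifiers import dinov2_vits14_lc

classifier = dinov2_vits14_lc(layers=4, pretrained=True)
logits = classifier(torch.rand(1, 3, 224, 224))    # 224 is divisible by the 14-pixel patches
print(logits.shape)                                # torch.Size([1, 1000]) ImageNet-1k logits
```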
diff --git a/openlrm/models/encoders/dinov2/hub/depth/__init__.py b/openlrm/models/encoders/dinov2/hub/depth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/depth/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .decode_heads import BNHead, DPTHead
+from .encoder_decoder import DepthEncoderDecoder
diff --git a/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py b/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py
@@ -0,0 +1,747 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import copy
+from functools import partial
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+
+from .ops import resize
+
+
+# XXX: (Untested) replacement for mmcv.imdenormalize()
+def _imdenormalize(img, mean, std, to_bgr=True):
+ import numpy as np
+
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = (img * std) + mean
+ if to_bgr:
+ img = img[::-1]
+ return img
+
+
+class DepthBaseDecodeHead(nn.Module):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (List): Input channels.
+ channels (int): Channels after modules, before conv_depth.
+ conv_layer (nn.Module): Conv layers. Default: None.
+ act_layer (nn.Module): Activation layers. Default: nn.ReLU.
+ loss_decode (dict): Config of decode loss.
+ Default: ().
+ sampler (dict|None): The config of depth map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ min_depth (int): Min depth in dataset setting.
+ Default: 1e-3.
+ max_depth (int): Max depth in dataset setting.
+ Default: None.
+ norm_layer (dict|None): Norm layers.
+ Default: None.
+ classify (bool): Whether predict depth in a cls.-reg. manner.
+ Default: False.
+ n_bins (int): The number of bins used in cls. step.
+ Default: 256.
+ bins_strategy (str): The discrete strategy used in cls. step.
+ Default: 'UD'.
+ norm_strategy (str): The norm strategy on cls. probability
+ distribution. Default: 'linear'
+ scale_up (str): Whether predict depth in a scale-up manner.
+ Default: False.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ conv_layer=None,
+ act_layer=nn.ReLU,
+ channels=96,
+ loss_decode=(),
+ sampler=None,
+ align_corners=False,
+ min_depth=1e-3,
+ max_depth=None,
+ norm_layer=None,
+ classify=False,
+ n_bins=256,
+ bins_strategy="UD",
+ norm_strategy="linear",
+ scale_up=False,
+ ):
+ super(DepthBaseDecodeHead, self).__init__()
+
+ self.in_channels = in_channels
+ self.channels = channels
+        self.conv_layer = conv_layer
+ self.act_layer = act_layer
+ self.loss_decode = loss_decode
+ self.align_corners = align_corners
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.norm_layer = norm_layer
+ self.classify = classify
+ self.n_bins = n_bins
+ self.scale_up = scale_up
+
+ if self.classify:
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
+
+ self.bins_strategy = bins_strategy
+ self.norm_strategy = norm_strategy
+ self.softmax = nn.Softmax(dim=1)
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
+
+ self.relu = nn.ReLU()
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, inputs, img_metas):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, img, inputs, img_metas, depth_gt):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): GT depth
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ depth_pred = self.forward(inputs, img_metas)
+ losses = self.losses(depth_pred, depth_gt)
+
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
+ losses.update(**log_imgs)
+
+ return losses
+
+ def forward_test(self, inputs, img_metas):
+ """Forward function for testing.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+
+ Returns:
+ Tensor: Output depth map.
+ """
+ return self.forward(inputs, img_metas)
+
+ def depth_pred(self, feat):
+ """Prediction each pixel."""
+ if self.classify:
+ logit = self.conv_depth(feat)
+
+ if self.bins_strategy == "UD":
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+ elif self.bins_strategy == "SID":
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+
+ # following Adabins, default linear
+ if self.norm_strategy == "linear":
+ logit = torch.relu(logit)
+ eps = 0.1
+ logit = logit + eps
+ logit = logit / logit.sum(dim=1, keepdim=True)
+ elif self.norm_strategy == "softmax":
+ logit = torch.softmax(logit, dim=1)
+ elif self.norm_strategy == "sigmoid":
+ logit = torch.sigmoid(logit)
+ logit = logit / logit.sum(dim=1, keepdim=True)
+
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
+
+ else:
+ if self.scale_up:
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
+ else:
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
+ return output
+
+ def losses(self, depth_pred, depth_gt):
+ """Compute depth loss."""
+ loss = dict()
+ depth_pred = resize(
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
+ )
+ if not isinstance(self.loss_decode, nn.ModuleList):
+ losses_decode = [self.loss_decode]
+ else:
+ losses_decode = self.loss_decode
+ for loss_decode in losses_decode:
+ if loss_decode.loss_name not in loss:
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
+ else:
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
+ return loss
+
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
+ import numpy as np
+
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
+ show_img = show_img.numpy().astype(np.float32)
+ show_img = _imdenormalize(
+ show_img,
+ img_meta["img_norm_cfg"]["mean"],
+ img_meta["img_norm_cfg"]["std"],
+ img_meta["img_norm_cfg"]["to_rgb"],
+ )
+ show_img = np.clip(show_img, 0, 255)
+ show_img = show_img.astype(np.uint8)
+ show_img = show_img[:, :, ::-1]
+ show_img = show_img.transpose(0, 2, 1)
+ show_img = show_img.transpose(1, 0, 2)
+
+ depth_pred = depth_pred / torch.max(depth_pred)
+ depth_gt = depth_gt / torch.max(depth_gt)
+
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
+
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
+
+
+class BNHead(DepthBaseDecodeHead):
+ """Just a batchnorm."""
+
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
+ super().__init__(**kwargs)
+ self.input_transform = input_transform
+ self.in_index = in_index
+ self.upsample = upsample
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
+ if self.classify:
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
+ else:
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if "concat" in self.input_transform:
+ inputs = [inputs[i] for i in self.in_index]
+ if "resize" in self.input_transform:
+ inputs = [
+ resize(
+ input=x,
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
+ mode="bilinear",
+ align_corners=self.align_corners,
+ )
+ for x in inputs
+ ]
+ inputs = torch.cat(inputs, dim=1)
+ elif self.input_transform == "multiple_select":
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
+ """Forward function for feature maps before classifying each pixel with
+ ``self.cls_seg`` fc.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ Returns:
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
+ H, W) which is feature map for last layer of decoder head.
+ """
+ # accept lists (for cls token)
+ inputs = list(inputs)
+ for i, x in enumerate(inputs):
+ if len(x) == 2:
+ x, cls_token = x[0], x[1]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ cls_token = cls_token[:, :, None, None].expand_as(x)
+ inputs[i] = torch.cat((x, cls_token), 1)
+ else:
+ x = x[0]
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ inputs[i] = x
+ x = self._transform_inputs(inputs)
+ # feats = self.bn(x)
+ return x
+
+ def forward(self, inputs, img_metas=None, **kwargs):
+ """Forward function."""
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
+ output = self.depth_pred(output)
+ return output
+
+
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
+ False. Default: "auto".
+ conv_layer (nn.Module): Convolution layer. Default: None,
+ which means using conv2d.
+ norm_layer (nn.Module): Normalization layer. Default: None.
+ act_layer (nn.Module): Activation layer. Default: nn.ReLU.
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = "conv_block"
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias="auto",
+ conv_layer=nn.Conv2d,
+ norm_layer=None,
+ act_layer=nn.ReLU,
+ inplace=True,
+ with_spectral_norm=False,
+ padding_mode="zeros",
+ order=("conv", "norm", "act"),
+ ):
+ super(ConvModule, self).__init__()
+ official_padding_mode = ["zeros", "circular"]
+ self.conv_layer = conv_layer
+ self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(["conv", "norm", "act"])
+
+ self.with_norm = norm_layer is not None
+ self.with_activation = act_layer is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == "auto":
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ if padding_mode == "zeros":
+ padding_layer = nn.ZeroPad2d
+ else:
+ raise AssertionError(f"Unsupported padding mode: {padding_mode}")
+ self.pad = padding_layer(padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = self.conv_layer(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index("norm") > order.index("conv"):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+            self.norm_name = "norm"
+            norm = norm_layer(num_features=norm_channels)
+            self.add_module(self.norm_name, norm)
+            if self.with_bias:
+                from torch.nn.modules.batchnorm import _BatchNorm
+                from torch.nn.modules.instancenorm import _InstanceNorm
+
+                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+                    warnings.warn("Unnecessary conv bias before batch/instance norm")
+        else:
+            self.norm_name = None
+
+ # build activation layer
+ if self.with_activation:
+            # some activations (e.g. nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)
+            # do not accept an 'inplace' argument, so only wrap the ones that do
+            if act_layer not in (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU):
+                act_layer = partial(act_layer, inplace=inplace)
+ self.activate = act_layer()
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, "init_weights"):
+            if self.with_activation and isinstance(self.activate, nn.LeakyReLU):
+ nonlinearity = "leaky_relu"
+ a = 0.01 # XXX: default negative_slope
+ else:
+ nonlinearity = "relu"
+ a = 0
+ if hasattr(self.conv, "weight") and self.conv.weight is not None:
+ nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
+ if hasattr(self.conv, "bias") and self.conv.bias is not None:
+ nn.init.constant_(self.conv.bias, 0)
+ if self.with_norm:
+ if hasattr(self.norm, "weight") and self.norm.weight is not None:
+ nn.init.constant_(self.norm.weight, 1)
+ if hasattr(self.norm, "bias") and self.norm.bias is not None:
+ nn.init.constant_(self.norm.bias, 0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == "conv":
+ if self.with_explicit_padding:
+ x = self.pad(x)
+ x = self.conv(x)
+ elif layer == "norm" and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == "act" and activate and self.with_activation:
+ x = self.activate(x)
+ return x
+
+
+class Interpolate(nn.Module):
+ def __init__(self, scale_factor, mode, align_corners=False):
+ super(Interpolate, self).__init__()
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+ return x
+
+
+class HeadDepth(nn.Module):
+ def __init__(self, features):
+ super(HeadDepth, self).__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(self, x):
+ x = self.head(x)
+ return x
+
+
+class ReassembleBlocks(nn.Module):
+    """ViTPostProcessBlock: processes the cls_token in the ViT backbone output
+    and rearranges the feature vectors into feature maps.
+ Args:
+ in_channels (int): ViT feature channels. Default: 768.
+ out_channels (List): output channels of each stage.
+ Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
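+
+    Example (illustrative shapes): each element passed to ``forward`` is a
+        ``(feature, cls_token)`` pair, e.g. a (B, 768, H/16, W/16) feature map
+        with a (B, 768) cls_token; the four outputs are resized to strides
+        4, 8, 16 and 32 of the original image.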
+ """
+
+ def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
+ super(ReassembleBlocks, self).__init__()
+
+ assert readout_type in ["ignore", "add", "project"]
+ self.readout_type = readout_type
+ self.patch_size = patch_size
+
+ self.projects = nn.ModuleList(
+ [
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ act_layer=None,
+ )
+ for out_channel in out_channels
+ ]
+ )
+
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+ if self.readout_type == "project":
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
+
+ def forward(self, inputs):
+ assert isinstance(inputs, list)
+ out = []
+ for i, x in enumerate(inputs):
+ assert len(x) == 2
+ x, cls_token = x[0], x[1]
+ feature_shape = x.shape
+ if self.readout_type == "project":
+ x = x.flatten(2).permute((0, 2, 1))
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ x = x.permute(0, 2, 1).reshape(feature_shape)
+ elif self.readout_type == "add":
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
+ x = x.reshape(feature_shape)
+ else:
+ pass
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+ out.append(x)
+ return out
+
+
+class PreActResidualConvUnit(nn.Module):
+ """ResidualConvUnit, pre-activate residual unit.
+ Args:
+ in_channels (int): number of channels in the input feature map.
+ act_layer (nn.Module): activation layer.
+ norm_layer (nn.Module): norm layer.
+ stride (int): stride of the first block. Default: 1
+ dilation (int): dilation rate for convs layers. Default: 1.
+ """
+
+ def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
+ super(PreActResidualConvUnit, self).__init__()
+
+ self.conv1 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ self.conv2 = ConvModule(
+ in_channels,
+ in_channels,
+ 3,
+ padding=1,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ bias=False,
+ order=("act", "conv", "norm"),
+ )
+
+ def forward(self, inputs):
+ inputs_ = inputs.clone()
+ x = self.conv1(inputs)
+ x = self.conv2(x)
+ return x + inputs_
+
+
+class FeatureFusionBlock(nn.Module):
+ """FeatureFusionBlock, merge feature map from different stages.
+ Args:
+ in_channels (int): Input channels.
+ act_layer (nn.Module): activation layer for ResidualConvUnit.
+ norm_layer (nn.Module): normalization layer.
+        expand (bool): Whether to expand the channels in the post process block.
+            Default: False.
+ align_corners (bool): align_corner setting for bilinear upsample.
+ Default: True.
+ """
+
+ def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
+ super(FeatureFusionBlock, self).__init__()
+
+ self.in_channels = in_channels
+ self.expand = expand
+ self.align_corners = align_corners
+
+ self.out_channels = in_channels
+ if self.expand:
+ self.out_channels = in_channels // 2
+
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
+
+ self.res_conv_unit1 = PreActResidualConvUnit(
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+ )
+ self.res_conv_unit2 = PreActResidualConvUnit(
+ in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+ )
+
+ def forward(self, *inputs):
+ x = inputs[0]
+ if len(inputs) == 2:
+ if x.shape != inputs[1].shape:
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
+ else:
+ res = inputs[1]
+ x = x + self.res_conv_unit1(res)
+ x = self.res_conv_unit2(x)
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
+ x = self.project(x)
+ return x
+
+
+class DPTHead(DepthBaseDecodeHead):
+ """Vision Transformers for Dense Prediction.
+    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
+ Args:
+ embed_dims (int): The embed dimension of the ViT backbone.
+ Default: 768.
+ post_process_channels (List): Out channels of post process conv
+ layers. Default: [96, 192, 384, 768].
+ readout_type (str): Type of readout operation. Default: 'ignore'.
+ patch_size (int): The patch size. Default: 16.
+        expand_channels (bool): Whether to expand the channels in the post
+            process blocks. Default: False.
+ """
+
+ def __init__(
+ self,
+ embed_dims=768,
+ post_process_channels=[96, 192, 384, 768],
+ readout_type="ignore",
+ patch_size=16,
+ expand_channels=False,
+ **kwargs,
+ ):
+ super(DPTHead, self).__init__(**kwargs)
+
+ self.in_channels = self.in_channels
+ self.expand_channels = expand_channels
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
+
+ self.post_process_channels = [
+            int(channel * math.pow(2, i)) if expand_channels else channel for i, channel in enumerate(post_process_channels)
+ ]
+ self.convs = nn.ModuleList()
+ for channel in self.post_process_channels:
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
+ self.fusion_blocks = nn.ModuleList()
+ for _ in range(len(self.convs)):
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
+ self.fusion_blocks[0].res_conv_unit1 = None
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
+ self.num_fusion_blocks = len(self.fusion_blocks)
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
+ self.num_post_process_channels = len(self.post_process_channels)
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
+ assert self.num_reassemble_blocks == self.num_post_process_channels
+ self.conv_depth = HeadDepth(self.channels)
+
+ def forward(self, inputs, img_metas):
+ assert len(inputs) == self.num_reassemble_blocks
+ x = [inp for inp in inputs]
+ x = self.reassemble_blocks(x)
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
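+        # fuse from the deepest stage outwards: start from x[-1], then merge
+        # x[-2], x[-3], ... while each fusion block upsamples by a factor of 2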
+ out = self.fusion_blocks[0](x[-1])
+ for i in range(1, len(self.fusion_blocks)):
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
+ out = self.project(out)
+ out = self.depth_pred(out)
+ return out
diff --git a/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py b/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py
@@ -0,0 +1,351 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ops import resize
+
+
+def add_prefix(inputs, prefix):
+ """Add prefix for dict.
+
+ Args:
+ inputs (dict): The input dict with str keys.
+ prefix (str): The prefix to add.
+
+    Returns:
+        dict: The dict with keys updated with ``prefix``.
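+
+    Example (illustrative):
+        >>> add_prefix({"loss_depth": 1.0}, "decode")
+        {'decode.loss_depth': 1.0}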
+ """
+
+ outputs = dict()
+ for name, value in inputs.items():
+ outputs[f"{prefix}.{name}"] = value
+
+ return outputs
+
+
+class DepthEncoderDecoder(nn.Module):
+ """Encoder Decoder depther.
+
+ EncoderDecoder typically consists of backbone and decode_head.
+ """
+
+ def __init__(self, backbone, decode_head):
+ super(DepthEncoderDecoder, self).__init__()
+
+ self.backbone = backbone
+ self.decode_head = decode_head
+ self.align_corners = self.decode_head.align_corners
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ return self.backbone(img)
+
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
+ """Encode images with backbone and decode into a depth estimation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+        # clamp the predicted depth to the configured [min_depth, max_depth] range
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
+ if rescale:
+ if size is None:
+ if img_metas is not None:
+ size = img_metas[0]["ori_shape"][:2]
+ else:
+ size = img.shape[2:]
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
+ losses.update(add_prefix(loss_decode, "decode"))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ depth_pred = self.decode_head.forward_test(x, img_metas)
+ return depth_pred
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ depth = self.encode_decode(img, None)
+
+ return depth
+
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ depth_gt (Tensor): Depth gt
+ used if the architecture supports depth estimation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ # the last of x saves the info from neck
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
+
+ losses.update(loss_decode)
+
+ return losses
+
+ def whole_inference(self, img, img_meta, rescale, size=None):
+ """Inference with full image."""
+ return self.encode_decode(img, img_meta, rescale, size=size)
+
+ def slide_inference(self, img, img_meta, rescale, stride, crop_size):
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+ """
+
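+        # Worked example (numbers are assumptions for illustration):
+        # h_img=480, h_crop=256, h_stride=128 -> h_grids = (480 - 256 + 127) // 128 + 1 = 3
+        # windows, starting at y1 = 0, 128 and 224 (the last window is shifted
+        # back so it ends exactly at the image border).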
+ h_stride, w_stride = stride
+ h_crop, w_crop = crop_size
+ batch_size, _, h_img, w_img = img.size()
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = img[:, :, y1:y2, x1:x2]
+ depth_pred = self.encode_decode(crop_img, img_meta, rescale)
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ if torch.onnx.is_in_onnx_export():
+ # cast count_mat to constant while exporting to ONNX
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+ preds = preds / count_mat
+ return preds
+
+ def inference(self, img, img_meta, rescale, size=None, mode="whole"):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `depth/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output depth map.
+ """
+
+ assert mode in ["slide", "whole"]
+ ori_shape = img_meta[0]["ori_shape"]
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+ if mode == "slide":
+ depth_pred = self.slide_inference(img, img_meta, rescale)
+ else:
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
+ output = depth_pred
+ flip = img_meta[0]["flip"]
+ if flip:
+ flip_direction = img_meta[0]["flip_direction"]
+ assert flip_direction in ["horizontal", "vertical"]
+ if flip_direction == "horizontal":
+ output = output.flip(dims=(3,))
+ elif flip_direction == "vertical":
+ output = output.flip(dims=(2,))
+
+ return output
+
+ def simple_test(self, img, img_meta, rescale=True):
+ """Simple test with single image."""
+ depth_pred = self.inference(img, img_meta, rescale)
+ if torch.onnx.is_in_onnx_export():
+            # our inference backend only supports 4D output
+ depth_pred = depth_pred.unsqueeze(0)
+ return depth_pred
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented depth logit inplace
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
+ for i in range(1, len(imgs)):
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
+ depth_pred += cur_depth_pred
+ depth_pred /= len(imgs)
+ depth_pred = depth_pred.cpu().numpy()
+ # unravel batch dim
+ depth_pred = list(depth_pred)
+ return depth_pred
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
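+        # Illustrative img_metas entry (keys taken from the docstrings above;
+        # the concrete values are assumptions):
+        #   img_metas = [[{'img_shape': (480, 640, 3), 'ori_shape': (480, 640, 3),
+        #                  'pad_shape': (480, 640, 3), 'scale_factor': 1.0,
+        #                  'flip': False}]]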
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
+ if not isinstance(var, list):
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
+        # all images in the same aug batch must share the same ori_shape and
+        # pad shape
+ for img_meta in img_metas:
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_["img_shape"] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+        and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+            data_batch (dict): The output of the dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+
+ # split losses and images
+ real_losses = {}
+ log_imgs = {}
+ for k, v in losses.items():
+ if "img" in k:
+ log_imgs[k] = v
+ else:
+ real_losses[k] = v
+
+ loss, log_vars = self._parse_losses(real_losses)
+
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+        """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+        """
+        import torch.distributed as dist
+
+        log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
+
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
+
+ log_vars["loss"] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
diff --git a/openlrm/models/encoders/dinov2/hub/depth/ops.py b/openlrm/models/encoders/dinov2/hub/depth/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/depth/ops.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch.nn.functional as F
+
+
+def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+ if (
+ (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
+ and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)
+ ):
+ warnings.warn(
+ f"When align_corners={align_corners}, "
+                        "the output would be more aligned if "
+ f"input size {(input_h, input_w)} is `x+1` and "
+ f"out size {(output_h, output_w)} is `nx+1`"
+ )
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/openlrm/models/encoders/dinov2/hub/depthers.py b/openlrm/models/encoders/dinov2/hub/depthers.py
new file mode 100644
index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/depthers.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+
+from .backbones import _make_dinov2_model
+from .depth import BNHead, DepthEncoderDecoder, DPTHead
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
+
+
+class Weights(Enum):
+ NYU = "NYU"
+ KITTI = "KITTI"
+
+
+def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
+ if not pretrained: # Default
+ return (0.001, 10.0)
+
+ # Pretrained, set according to the training dataset for the provided weights
+ if weights == Weights.KITTI:
+ return (0.001, 80.0)
+
+ if weights == Weights.NYU:
+ return (0.001, 10.0)
+
+ return (0.001, 10.0)
+
+
+def _make_dinov2_linear_depth_head(
+ *,
+ embed_dim: int,
+ layers: int,
+ min_depth: float,
+ max_depth: float,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+
+ if layers == 1:
+ in_index = [0]
+ else:
+ assert layers == 4
+ in_index = [0, 1, 2, 3]
+
+ return BNHead(
+ classify=True,
+ n_bins=256,
+ bins_strategy="UD",
+ norm_strategy="linear",
+ upsample=4,
+ in_channels=[embed_dim] * len(in_index),
+ in_index=in_index,
+ input_transform="resize_concat",
+ channels=embed_dim * len(in_index) * 2,
+ align_corners=False,
+        min_depth=min_depth,
+        max_depth=max_depth,
+ loss_decode=(),
+ )
+
+
+def _make_dinov2_linear_depther(
+ *,
+ arch_name: str = "vit_large",
+ layers: int = 4,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.NYU,
+ depth_range: Optional[Tuple[float, float]] = None,
+ **kwargs,
+):
+ if layers not in (1, 4):
+ raise AssertionError(f"Unsupported number of layers: {layers}")
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ if depth_range is None:
+ depth_range = _get_depth_range(pretrained, weights)
+ min_depth, max_depth = depth_range
+
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+ embed_dim = backbone.embed_dim
+ patch_size = backbone.patch_size
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
+ linear_depth_head = _make_dinov2_linear_depth_head(
+ embed_dim=embed_dim,
+ layers=layers,
+ min_depth=min_depth,
+ max_depth=max_depth,
+ )
+
+ layer_count = {
+ "vit_small": 12,
+ "vit_base": 12,
+ "vit_large": 24,
+ "vit_giant2": 40,
+ }[arch_name]
+
+ if layers == 4:
+ out_index = {
+ "vit_small": [2, 5, 8, 11],
+ "vit_base": [2, 5, 8, 11],
+ "vit_large": [4, 11, 17, 23],
+ "vit_giant2": [9, 19, 29, 39],
+ }[arch_name]
+ else:
+ assert layers == 1
+ out_index = [layer_count - 1]
+
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
+ model.backbone.forward = partial(
+ backbone.get_intermediate_layers,
+ n=out_index,
+ reshape=True,
+ return_class_token=True,
+ norm=False,
+ )
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))
+
+ if pretrained:
+ layers_str = str(layers) if layers == 4 else ""
+ weights_str = weights.value.lower()
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_linear_depther(
+ arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+ )
+
+
+def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
+ return DPTHead(
+ in_channels=[embed_dim] * 4,
+ channels=256,
+ embed_dims=embed_dim,
+ post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
+ readout_type="project",
+ min_depth=min_depth,
+ max_depth=max_depth,
+ loss_decode=(),
+ )
+
+
+def _make_dinov2_dpt_depther(
+ *,
+ arch_name: str = "vit_large",
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.NYU,
+ depth_range: Optional[Tuple[float, float]] = None,
+ **kwargs,
+):
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ if depth_range is None:
+ depth_range = _get_depth_range(pretrained, weights)
+ min_depth, max_depth = depth_range
+
+ backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+ model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
+ dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)
+
+ out_index = {
+ "vit_small": [2, 5, 8, 11],
+ "vit_base": [2, 5, 8, 11],
+ "vit_large": [4, 11, 17, 23],
+ "vit_giant2": [9, 19, 29, 39],
+ }[arch_name]
+
+ model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
+ model.backbone.forward = partial(
+ backbone.get_intermediate_layers,
+ n=out_index,
+ reshape=True,
+ return_class_token=True,
+ norm=False,
+ )
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))
+
+ if pretrained:
+ weights_str = weights.value.lower()
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
+ checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+ return _make_dinov2_dpt_depther(
+ arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+ )
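+
+
+# Illustrative usage sketch (an assumption, not part of the upstream API; it
+# presumes the pretrained weights are downloadable and `img` is a (1, 3, H, W)
+# float tensor):
+#
+#   model = dinov2_vitl14_dd(pretrained=True, weights=Weights.NYU).eval()
+#   with torch.inference_mode():
+#       depth = model.whole_inference(img, img_meta=None, rescale=True)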
diff --git a/openlrm/models/encoders/dinov2/hub/utils.py b/openlrm/models/encoders/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
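+    # e.g. _make_dinov2_model_name("vit_large", 14) -> "dinov2_vitl14"
+    #      _make_dinov2_model_name("vit_small", 14, 4) -> "dinov2_vits14_reg4"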
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
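+    """Pad every dimension after the channel dim so it becomes a multiple of
+    ``multiple``, splitting the padding as evenly as possible on both sides.
+
+    Illustrative example (sizes assumed): with multiple=14, a (1, 3, 518, 520)
+    input is padded to (1, 3, 518, 532); 518 is already a multiple of 14, while
+    520 gets 6 padding pixels on each side.
+    """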
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/openlrm/models/encoders/dinov2/layers/__init__.py b/openlrm/models/encoders/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..77967aa6ccfae24c39b8e167c83dd77073fd68fb
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+# ********** Modified by Zexin He in 2023-2024 **********
+# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
+from .block import Block, BlockWithModulation
+# ********************************************************
+from .attention import MemEffAttention
diff --git a/openlrm/models/encoders/dinov2/layers/attention.py b/openlrm/models/encoders/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Attention)")
+ else:
+ warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
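+        # qkv has shape (3, B, num_heads, N, C // num_heads) after the permute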
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/openlrm/models/encoders/dinov2/layers/block.py b/openlrm/models/encoders/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5b50118c1579fd30cda0c2d60b95c85eb04204
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/block.py
@@ -0,0 +1,296 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (Block)")
+ else:
+ warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# Override forward with modulation input
+class BlockWithModulation(Block):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: Tensor, mod: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x, mod)))
+
+ def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x, mod)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet")
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, mod))
+ x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, mod)
+ x = x + ffn_residual_func(x, mod)
+ return x
+# ********************************************************
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
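+    # Illustrative numbers (assumed): with b=8 and sample_drop_ratio=0.25, only
+    # 6 of the 8 samples receive the residual, scaled by 8/6 so the update
+    # matches the full-batch residual in expectation.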
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    Index-select the tensors (when ``branges`` is given), concatenate them, and
+    return the corresponding block-diagonal attn_bias from the cache.
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning)
+ # ********************************************************
+
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/openlrm/models/encoders/dinov2/layers/dino_head.py b/openlrm/models/encoders/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/openlrm/models/encoders/dinov2/layers/drop_path.py b/openlrm/models/encoders/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/openlrm/models/encoders/dinov2/layers/layer_scale.py b/openlrm/models/encoders/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/openlrm/models/encoders/dinov2/layers/mlp.py b/openlrm/models/encoders/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/openlrm/models/encoders/dinov2/layers/patch_embed.py b/openlrm/models/encoders/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
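+        flatten_embedding: Whether to flatten the patch grid into a sequence.
+
+    Example (illustrative): img_size=224 and patch_size=16 give a 14x14 patch
+    grid, i.e. ``num_patches = 196``, and ``forward`` maps a (B, 3, 224, 224)
+    input to a (B, 196, embed_dim) output.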
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py b/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
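+    # Minimal usage sketch (illustrative only): with the defaults,
+    # hidden_features == in_features, so the layer maps (B, N, D) -> (B, N, D).
+    #
+    #   ffn = SwiGLUFFN(in_features=384)
+    #   out = ffn(torch.randn(2, 16, 384))  # w12: 384 -> 768, chunked into two 384-d
+    #                                       # halves, silu(x1) * x2, then w3 -> 384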
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
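+        # The 2/3 factor (a common SwiGLU convention) keeps the gated, two-matrix input
+        # projection roughly comparable in parameter count to a plain MLP with the same
+        # nominal hidden size, and the result is rounded up to a multiple of 8 for
+        # hardware-friendly shapes. For example, hidden_features=4096 becomes
+        # int(4096 * 2 / 3) = 2730, then (2730 + 7) // 8 * 8 = 2736.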
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/openlrm/models/encoders/dinov2/models/__init__.py b/openlrm/models/encoders/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/openlrm/models/encoders/dinov2/models/vision_transformer.py b/openlrm/models/encoders/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90ac2be1fe294a0db6080cd24155629083d3ec9
--- /dev/null
+++ b/openlrm/models/encoders/dinov2/models/vision_transformer.py
@@ -0,0 +1,443 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+# ******************************************************************************
+# Code modified by Zexin He in 2023-2024.
+# Modifications are marked with clearly visible comments
+# licensed under the Apache License, Version 2.0.
+# ******************************************************************************
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation
+# ********************************************************
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ # ********** Modified by Zexin He in 2023-2024 **********
+ modulation_dim: int = None,
+ # ********************************************************
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ block_norm_layer = None
+ if modulation_dim is not None:
+ from ....modulate import ModLN
+ block_norm_layer = partial(ModLN, mod_dim=modulation_dim)
+ else:
+ block_norm_layer = nn.LayerNorm
+ block_norm_layer = partial(block_norm_layer, eps=1e-6)
+ # ********************************************************
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ # ********** Modified by Zexin He in 2023-2024 **********
+ norm_layer=block_norm_layer,
+ # ********************************************************
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
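+        # Block chunking (used for FSDP wrapping): with block_chunks > 0 the depth blocks
+        # are grouped into chunks of size depth // block_chunks, and each chunk is prefixed
+        # with nn.Identity() placeholders so that global block indices are preserved.
+        # E.g. depth=12, block_chunks=3 gives chunks [0..3], [Id x4, 4..7], [Id x8, 8..11].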
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ # hacking unused mask_token for better DDP
+ # self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+ # ********************************************************
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
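+    # Positional-embedding interpolation, illustrated with assumed numbers: if the model
+    # was trained with a 16x16 patch grid (N = 256) and the input yields a 32x32 grid,
+    # the stored grid is bicubically resized by a factor of ~2 per axis;
+    # interpolate_offset (0.1 by default) nudges the scale factors slightly to avoid the
+    # floating-point corner case discussed in the linked DINO issue.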
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ # ********** Modified by Zexin He in 2023-2024 **********
+ raise NotImplementedError("Masking is not supported in hacked DINOv2")
+ # x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+ # ********************************************************
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ # ********** Modified by Zexin He in 2023-2024 **********
+ def forward_features(self, x, masks=None, mod=None):
+ if isinstance(x, list):
+ raise DeprecationWarning("forward_features_list is deprecated, use forward_features")
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ if mod is None:
+ for blk in self.blocks:
+ x = blk(x)
+ else:
+ for blk in self.blocks:
+ x = blk(x, mod)
+
+ x_norm = self.norm(x)
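+        # Returned tensors (shapes, assuming L patch tokens and R register tokens):
+        # x_norm_clstoken [N, D], x_norm_regtokens [N, R, D],
+        # x_norm_patchtokens [N, L, D], x_prenorm [N, 1 + R + L, D].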
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ # ********************************************************
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+            for blk in block_chunk[i:]: # skip the leading nn.Identity() placeholders
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+# ********** Modified by Zexin He in 2023-2024 **********
+# block class selected from Block and BlockWithModulation
+
+def _block_cls(**kwargs):
+ modulation_dim = kwargs.get("modulation_dim", None)
+ if modulation_dim is None:
+ block_cls = Block
+ else:
+ block_cls = BlockWithModulation
+ return block_cls
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+# ********************************************************
diff --git a/openlrm/models/encoders/dinov2_wrapper.py b/openlrm/models/encoders/dinov2_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cb463ed812c4b92b20d7c49b0219eecd607c607
--- /dev/null
+++ b/openlrm/models/encoders/dinov2_wrapper.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class Dinov2Wrapper(nn.Module):
+ """
+ Dino v2 wrapper using original implementation, hacked with modulation.
+ """
+ def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True):
+ super().__init__()
+ self.modulation_dim = modulation_dim
+ self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim)
+ if freeze:
+ if modulation_dim is not None:
+ raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.")
+ self._freeze()
+
+ def _freeze(self):
+        logger.warning("======== Freezing Dinov2Wrapper ========")
+ self.model.eval()
+ for name, param in self.model.named_parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True):
+ from importlib import import_module
+ dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__)
+ model_fn = getattr(dinov2_hub, model_name)
+ logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.")
+ model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained)
+ return model
+
+ @torch.compile
+ def forward(self, image: torch.Tensor, mod: torch.Tensor = None):
+ # image: [N, C, H, W]
+ # mod: [N, D] or None
+ # RGB image with [0,1] scale and properly sized
+ if self.modulation_dim is None:
+ assert mod is None, "Unexpected modulation input in dinov2 forward."
+ outs = self.model(image, is_training=True)
+ else:
+ assert mod is not None, "Modulation input is required in modulated dinov2 forward."
+ outs = self.model(image, mod=mod, is_training=True)
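+        # Concatenate the class token and patch tokens into a single sequence:
+        # [N, 1, D] + [N, L, D] -> [N, 1 + L, D], where L is the number of patch tokens.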
+ ret = torch.cat([
+ outs["x_norm_clstoken"].unsqueeze(dim=1),
+ outs["x_norm_patchtokens"],
+ ], dim=1)
+ return ret
diff --git a/openlrm/models/modeling_lrm.py b/openlrm/models/modeling_lrm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f865937cb7cf51f692297b726075af36bbcaab3e
--- /dev/null
+++ b/openlrm/models/modeling_lrm.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+from .embedder import CameraEmbedder
+from .transformer import TransformerDecoder
+from .rendering.synthesizer import TriplaneSynthesizer
+from .utils import zero_module
+import loratorch as lora
+from .swin_transformer import CrossAttentionLayer
+
+logger = get_logger(__name__)
+
+
+class ModelLRM(nn.Module):
+ """
+ Full model of the basic single-view large reconstruction model.
+ """
+ def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
+ transformer_dim: int, transformer_layers: int, transformer_heads: int,
+ triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
+ encoder_freeze: bool = True, encoder_type: str = 'dino',
+ encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
+ model_lora_rank: int = 0, conv_fuse=False,
+ swin_ca_fuse=False, ca_dim=32, ca_depth=2, ca_num_heads=8, ca_window_size=2):
+ super().__init__()
+
+ # attributes
+ self.encoder_feat_dim = encoder_feat_dim
+ self.camera_embed_dim = camera_embed_dim
+ self.triplane_low_res = triplane_low_res
+ self.triplane_high_res = triplane_high_res
+ self.triplane_dim = triplane_dim
+
+ self.conv_fuse = conv_fuse
+ self.swin_ca_fuse = swin_ca_fuse
+
+ # modules
+ self.encoder = self._encoder_fn(encoder_type)(
+ model_name=encoder_model_name,
+ freeze=encoder_freeze,
+ )
+ self.camera_embedder = CameraEmbedder(
+ raw_dim=12+4, embed_dim=camera_embed_dim,
+ )
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
+ if model_lora_rank > 0:
+ self.transformer = TransformerDecoder(
+ block_type='cond_mod',
+ num_layers=transformer_layers, num_heads=transformer_heads,
+ inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim,
+ lora_rank=model_lora_rank
+ )
+ lora.mark_only_lora_as_trainable(self.transformer)
+ else:
+ self.transformer = TransformerDecoder(
+ block_type='cond_mod',
+ num_layers=transformer_layers, num_heads=transformer_heads,
+ inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim,
+ )
+ self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
+ self.synthesizer = TriplaneSynthesizer(
+ triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray,
+ )
+
+ if model_lora_rank > 0:
+ if self.conv_fuse:
+ # self.front_back_conv = nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1)
+ # zero_module(self.front_back_conv)
+ self.front_back_conv = nn.ModuleList([
+ nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
+ nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]), # Using Layer Normalization
+ nn.GELU(), # Using GELU activation
+ nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
+ nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]), # Using Layer Normalization
+ nn.GELU(), # Using GELU activation
+ nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1)
+ ])
+ self.freeze_modules(encoder=True, camera_embedder=True,
+ pos_embed=False, transformer=False, upsampler=False,
+ synthesizer=False)
+ elif self.swin_ca_fuse:
+ self.swin_cross_attention = CrossAttentionLayer(dim=ca_dim, depth=ca_depth, num_heads=ca_num_heads, window_size=ca_window_size)
+ self.freeze_modules(encoder=True, camera_embedder=True,
+ pos_embed=False, transformer=False, upsampler=False,
+ synthesizer=False)
+ else:
+ raise ValueError("You need to specify a method for fusing the front and the back.")
+
+
+ def freeze_modules(self, encoder=False, camera_embedder=False,
+ pos_embed=False, transformer=False, upsampler=False,
+ synthesizer=False):
+ """
+ Freeze specified modules
+ """
+ if encoder:
+ for param in self.encoder.parameters():
+ param.requires_grad = False
+ if camera_embedder:
+ for param in self.camera_embedder.parameters():
+ param.requires_grad = False
+ if pos_embed:
+            # pos_embed is a single nn.Parameter (not a module), so there is no
+            # .parameters() to iterate; toggle its requires_grad flag directly.
+            self.pos_embed.requires_grad = False
+ if transformer:
+ for param in self.transformer.parameters():
+ param.requires_grad = False
+ if upsampler:
+ for param in self.upsampler.parameters():
+ param.requires_grad = False
+ if synthesizer:
+ for param in self.synthesizer.parameters():
+ param.requires_grad = False
+
+ @staticmethod
+ def _encoder_fn(encoder_type: str):
+ encoder_type = encoder_type.lower()
+ assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
+ if encoder_type == 'dino':
+ from .encoders.dino_wrapper import DinoWrapper
+ logger.info("Using DINO as the encoder")
+ return DinoWrapper
+ elif encoder_type == 'dinov2':
+ from .encoders.dinov2_wrapper import Dinov2Wrapper
+ logger.info("Using DINOv2 as the encoder")
+ return Dinov2Wrapper
+
+ def forward_transformer(self, image_feats, camera_embeddings):
+ assert image_feats.shape[0] == camera_embeddings.shape[0], \
+ "Batch size mismatch for image_feats and camera_embeddings!"
+ N = image_feats.shape[0]
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
+ x = self.transformer(
+ x,
+ cond=image_feats,
+ mod=camera_embeddings,
+ )
+ return x
+
+ def reshape_upsample(self, tokens):
+ N = tokens.shape[0]
+ H = W = self.triplane_low_res
+ x = tokens.view(N, 3, H, W, -1)
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
+ x = self.upsampler(x) # [3*N, D', H', W']
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
+ x = x.contiguous()
+ return x
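+    # Shape sketch for reshape_upsample (illustrative numbers): with triplane_low_res=32
+    # and transformer_dim=D, tokens [N, 3*32*32, D] are viewed as [N, 3, 32, 32, D],
+    # folded to [3N, D, 32, 32], upsampled by the stride-2 ConvTranspose2d to
+    # [3N, triplane_dim, 64, 64], and unfolded back to [N, 3, triplane_dim, 64, 64].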
+
+ @torch.compile
+ def forward_planes(self, image, camera):
+ # image: [N, C_img, H_img, W_img]
+ # camera: [N, D_cam_raw]
+ N = image.shape[0]
+
+ # encode image
+ image_feats = self.encoder(image)
+ assert image_feats.shape[-1] == self.encoder_feat_dim, \
+ f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
+
+ # embed camera
+ camera_embeddings = self.camera_embedder(camera)
+ assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
+ f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
+
+ # transformer generating planes
+ tokens = self.forward_transformer(image_feats, camera_embeddings)
+ planes = self.reshape_upsample(tokens)
+ assert planes.shape[0] == N, "Batch size mismatch for planes"
+ assert planes.shape[1] == 3, "Planes should have 3 channels"
+
+ return planes
+
+ def forward(self, image, source_camera, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size: int,
+ image_back=None,):
+ # image: [N, C_img, H_img, W_img]
+ # source_camera: [N, D_cam_raw]
+ # render_cameras: [N, M, D_cam_render]
+ # render_anchors: [N, M, 2]
+ # render_resolutions: [N, M, 1]
+ # render_bg_colors: [N, M, 1]
+ # render_region_size: int
+ assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
+ assert image.shape[0] == render_cameras.shape[0], "Batch size mismatch for image and render_cameras"
+ assert image.shape[0] == render_anchors.shape[0], "Batch size mismatch for image and render_anchors"
+ assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors"
+ N, M = render_cameras.shape[:2]
+
+ if image_back is not None:
+ front_planes = self.forward_planes(image, source_camera)
+ back_planes = self.forward_planes(image_back, source_camera)
+
+ # XY Plane
+ back_planes[:, 0, :, :, :] = torch.flip(back_planes[:, 0, :, :, :], dims=[-2, -1])
+ # XZ Plane
+ back_planes[:, 1, :, :, :] = torch.flip(back_planes[:, 1, :, :, :], dims=[-1])
+ # YZ Plane
+ back_planes[:, 2, :, :, :] = torch.flip(back_planes[:, 2, :, :, :], dims=[-2])
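+            # The flips above are intended to mirror the back-view triplanes into the same
+            # canonical frame as the front view before fusion (a reading of the axis
+            # conventions used here, not spelled out elsewhere in the code).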
+
+ # To fuse the front planes and the back planes
+ bs, num_planes, channels, height, width = front_planes.shape
+ if self.conv_fuse:
+ planes = torch.cat((front_planes, back_planes), dim=2)
+ planes = planes.reshape(-1, channels*2, height, width)
+ # Apply multiple convolutional layers
+ for layer in self.front_back_conv:
+ planes = layer(planes)
+
+ planes = planes.view(bs, num_planes, -1, height, width)
+ # planes = self.front_back_conv(planes).view(bs, num_planes, -1, height, width) # only one layer.
+ elif self.swin_ca_fuse:
+ front_planes = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() # [8, 3, 32, 64, 64] -> [24, 32, 4096] -> [24, 4096, 32]
+ back_planes = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
+ planes = self.swin_cross_attention(front_planes, back_planes, height, width)[0].permute(0, 2, 1).reshape(bs, num_planes, channels, height, width)
+ else:
+ planes = self.forward_planes(image, source_camera)
+
+ # render target views
+ render_results = self.synthesizer(planes, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size)
+ assert render_results['images_rgb'].shape[0] == N, "Batch size mismatch for render_results"
+ assert render_results['images_rgb'].shape[1] == M, "Number of rendered views should be consistent with render_cameras"
+
+ return {
+ 'planes': planes,
+ **render_results,
+ }
diff --git a/openlrm/models/modulate.py b/openlrm/models/modulate.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d2a0f0240cc1d596a9a544d56eac5ee7e03cc7d
--- /dev/null
+++ b/openlrm/models/modulate.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+
+class ModLN(nn.Module):
+ """
+ Modulation with adaLN.
+
+ References:
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
+ """
+ def __init__(self, inner_dim: int, mod_dim: int, eps: float):
+ super().__init__()
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
+ self.mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(mod_dim, inner_dim * 2),
+ )
+
+ @staticmethod
+ def modulate(x, shift, scale):
+ # x: [N, L, D]
+ # shift, scale: [N, D]
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D]
+ return self.modulate(self.norm(x), shift, scale) # [N, L, D]
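+    # Worked shape example (illustrative numbers): with inner_dim=1024 and mod_dim=1024,
+    # mod [N, 1024] -> SiLU + Linear -> [N, 2048] -> chunk -> shift, scale [N, 1024] each,
+    # which are broadcast over the token dimension of x [N, L, 1024].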
diff --git a/openlrm/models/rendering/__init__.py b/openlrm/models/rendering/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/openlrm/models/rendering/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/openlrm/models/rendering/synthesizer.py b/openlrm/models/rendering/synthesizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3833cfeb681e1a8f61d8244a36d03c6631d8b3de
--- /dev/null
+++ b/openlrm/models/rendering/synthesizer.py
@@ -0,0 +1,208 @@
+# ORIGINAL LICENSE
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Modified by Zexin He in 2023-2024.
+# The modifications are subject to the same license as the original.
+
+
+import itertools
+import torch
+import torch.nn as nn
+
+from .utils.renderer import ImportanceRenderer
+from .utils.ray_sampler import RaySampler
+
+
+class ShiftedSoftplus(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return nn.functional.softplus(x - 1)
+
+
+class OSGDecoder(nn.Module):
+ """
+ Triplane decoder that gives RGB and sigma values from sampled features.
+ Using ReLU here instead of Softplus in the original implementation.
+
+ Reference:
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+ """
+ def __init__(self, n_features: int,
+ hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(3 * n_features, hidden_dim),
+ activation(),
+ *itertools.chain(*[[
+ nn.Linear(hidden_dim, hidden_dim),
+ activation(),
+ ] for _ in range(num_layers - 2)]),
+ nn.Linear(hidden_dim, 1 + 3),
+ )
+ # init all bias to zero
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.zeros_(m.bias)
+
+ @torch.compile
+ def forward(self, sampled_features, ray_directions):
+ # Aggregate features by mean
+ # sampled_features = sampled_features.mean(1)
+ # Aggregate features by concatenation
+ _N, n_planes, _M, _C = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.contiguous().view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+
+ return {'rgb': rgb, 'sigma': sigma}
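+    # Notes on the decoder head: the three plane features are aggregated by concatenation
+    # rather than by mean, so the first Linear sees 3 * n_features inputs; the sigmoid
+    # clamp maps rgb into (-0.001, 1.001), and the raw first channel is returned as sigma
+    # for the ray marcher to activate.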
+
+
+class TriplaneSynthesizer(nn.Module):
+ """
+ Synthesizer that renders a triplane volume with planes and a camera.
+
+ Reference:
+ EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
+ """
+
+ DEFAULT_RENDERING_KWARGS = {
+ 'ray_start': 'auto',
+ 'ray_end': 'auto',
+ 'box_warp': 2.,
+ 'white_back': False,
+ 'disparity_space_sampling': False,
+ 'clamp_mode': 'softplus',
+ 'sampler_bbox_min': -1.,
+ 'sampler_bbox_max': 1.,
+ }
+
+ def __init__(self, triplane_dim: int, samples_per_ray: int):
+ super().__init__()
+
+ # attributes
+ self.triplane_dim = triplane_dim
+ self.rendering_kwargs = {
+ **self.DEFAULT_RENDERING_KWARGS,
+ 'depth_resolution': samples_per_ray // 2,
+ 'depth_resolution_importance': samples_per_ray // 2,
+ }
+
+ # renderings
+ self.renderer = ImportanceRenderer()
+ self.ray_sampler = RaySampler()
+
+ # modules
+ self.decoder = OSGDecoder(n_features=triplane_dim)
+
+ def forward(self, planes, cameras, anchors, resolutions, bg_colors, region_size: int):
+ # planes: (N, 3, D', H', W')
+ # cameras: (N, M, D_cam)
+ # anchors: (N, M, 2)
+ # resolutions: (N, M, 1)
+ # bg_colors: (N, M, 1)
+ # region_size: int
+ assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
+ assert planes.shape[0] == anchors.shape[0], "Batch size mismatch for planes and anchors"
+ assert cameras.shape[1] == anchors.shape[1], "Number of views mismatch for cameras and anchors"
+ N, M = cameras.shape[:2]
+
+ cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
+ intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
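+        # Camera packing assumed by this synthesizer: the first 16 values of each render
+        # camera are a flattened 4x4 cam2world matrix and the next 9 values a flattened
+        # 3x3 intrinsics matrix, i.e. D_cam >= 25.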
+
+ # Create a batch of rays for volume rendering
+ ray_origins, ray_directions = self.ray_sampler(
+ cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
+ intrinsics=intrinsics.reshape(-1, 3, 3),
+ resolutions=resolutions.reshape(-1, 1),
+ anchors=anchors.reshape(-1, 2),
+ region_size=region_size,
+ )
+ assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
+ assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
+
+ # Perform volume rendering
+ rgb_samples, depth_samples, weights_samples = self.renderer(
+ planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
+ bg_colors=bg_colors.reshape(-1, 1),
+ )
+
+ # Reshape into 'raw' neural-rendered image
+ Himg = Wimg = region_size
+ rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
+ depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
+ weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
+
+ return {
+ 'images_rgb': rgb_images,
+ 'images_depth': depth_images,
+ 'images_weight': weight_images,
+ }
+
+ def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
+ # planes: (N, 3, D', H', W')
+ # grid_size: int
+ # aabb: (N, 2, 3)
+ if aabb is None:
+ aabb = torch.tensor([
+ [self.rendering_kwargs['sampler_bbox_min']] * 3,
+ [self.rendering_kwargs['sampler_bbox_max']] * 3,
+ ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
+ assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
+ N = planes.shape[0]
+
+ # create grid points for triplane query
+ grid_points = []
+ for i in range(N):
+ grid_points.append(torch.stack(torch.meshgrid(
+ torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
+ torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
+ torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
+ indexing='ij',
+ ), dim=-1).reshape(-1, 3))
+ cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
+
+ features = self.forward_points(planes, cube_grid)
+
+ # reshape into grid
+ features = {
+ k: v.reshape(N, grid_size, grid_size, grid_size, -1)
+ for k, v in features.items()
+ }
+ return features
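+    # Scale note (illustrative numbers): forward_grid with grid_size=256 queries
+    # 256**3 = 16,777,216 points per sample; forward_points below evaluates them in
+    # chunks of 2**20 = 1,048,576 coordinates to bound peak memory.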
+
+ def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
+ # planes: (N, 3, D', H', W')
+ # points: (N, P, 3)
+ N, P = points.shape[:2]
+
+ # query triplane in chunks
+ outs = []
+ for i in range(0, points.shape[1], chunk_size):
+ chunk_points = points[:, i:i+chunk_size]
+
+ # query triplane
+ chunk_out = self.renderer.run_model_activated(
+ planes=planes,
+ decoder=self.decoder,
+ sample_coordinates=chunk_points,
+ sample_directions=torch.zeros_like(chunk_points),
+ options=self.rendering_kwargs,
+ )
+ outs.append(chunk_out)
+
+ # concatenate the outputs
+ point_features = {
+ k: torch.cat([out[k] for out in outs], dim=1)
+ for k in outs[0].keys()
+ }
+ return point_features
diff --git a/openlrm/models/rendering/utils/__init__.py b/openlrm/models/rendering/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34
--- /dev/null
+++ b/openlrm/models/rendering/utils/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
diff --git a/openlrm/models/rendering/utils/math_utils.py b/openlrm/models/rendering/utils/math_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af
--- /dev/null
+++ b/openlrm/models/rendering/utils/math_utils.py
@@ -0,0 +1,118 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
+    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
+    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+    # "cannot statically infer the expected size of a list in this context", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
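+# Doctest-style sketch for linspace (illustrative): for start = torch.zeros(2) and
+# stop = torch.ones(2), linspace(start, stop, 5) returns a tensor of shape [5, 2]
+# whose rows step evenly from 0.0 to 1.0, matching numpy.linspace broadcast over
+# the trailing dimensions.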
diff --git a/openlrm/models/rendering/utils/ray_marcher.py b/openlrm/models/rendering/utils/ray_marcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c686c196e043f44e2276f16b4a32e596c802e40
--- /dev/null
+++ b/openlrm/models/rendering/utils/ray_marcher.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He in 2023-2024.
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
+Based on the MipNeRF implementation (this version does not do any cone tracing, though).
+"""
+
+import torch
+import torch.nn as nn
+
+
+class MipRayMarcher2(nn.Module):
+ def __init__(self, activation_factory):
+ super().__init__()
+ self.activation_factory = activation_factory
+
+ def run_forward(self, colors, densities, depths, rendering_options, bg_colors=None):
+ deltas = depths[:, :, 1:] - depths[:, :, :-1]
+ colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+ densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
+ depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+
+ # using factory mode for better usability
+ densities_mid = self.activation_factory(rendering_options)(densities_mid)
+
+ density_delta = densities_mid * deltas
+
+ alpha = 1 - torch.exp(-density_delta)
+
+ alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
+ weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
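+        # Standard NeRF-style compositing: alpha_i = 1 - exp(-sigma_i * delta_i) and
+        # weights_i = alpha_i * prod_{j<i} (1 - alpha_j), i.e. opacity times the
+        # transmittance accumulated in front of sample i.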
+
+ composite_rgb = torch.sum(weights * colors_mid, -2)
+ weight_total = weights.sum(2)
+ composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+
+ # clip the composite to min/max range of depths
+ composite_depth = torch.nan_to_num(composite_depth, float('inf'))
+ composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
+
+ if rendering_options.get('white_back', False):
+ composite_rgb = composite_rgb + 1 - weight_total
+ else:
+ assert bg_colors is not None, "Must provide bg_colors if white_back is False"
+ composite_rgb = composite_rgb + bg_colors.unsqueeze(-1) * (1 - weight_total)
+
+ # rendered value scale is 0-1, comment out original mipnerf scaling
+ # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
+
+ return composite_rgb, composite_depth, weights
+
+
+ def forward(self, colors, densities, depths, rendering_options, bg_colors=None):
+ composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options, bg_colors=bg_colors)
+
+ return composite_rgb, composite_depth, weights
diff --git a/openlrm/models/rendering/utils/ray_sampler.py b/openlrm/models/rendering/utils/ray_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab9594be66d02df79ec2295dbd064906f748c2c
--- /dev/null
+++ b/openlrm/models/rendering/utils/ray_sampler.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He in 2023-2024.
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray sampler is a module that takes in camera matrices and a resolution and produces batches of ray origins and directions.
+Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
+"""
+
+import torch
+
+class RaySampler(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
+
+ @torch.compile
+ def forward(self, cam2world_matrix, intrinsics, resolutions, anchors, region_size):
+ """
+ Create batches of rays and return origins and directions.
+
+ cam2world_matrix: (N, 4, 4)
+ intrinsics: (N, 3, 3)
+ resolutions: (N, 1)
+ anchors: (N, 2)
+ region_size: int
+
+ ray_origins: (N, M, 3)
+        ray_dirs: (N, M, 3)
+ """
+
+ N, M = cam2world_matrix.shape[0], region_size**2
+ cam_locs_world = cam2world_matrix[:, :3, 3]
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ uv = torch.stack(torch.meshgrid(
+ torch.arange(region_size, dtype=torch.float32, device=cam2world_matrix.device),
+ torch.arange(region_size, dtype=torch.float32, device=cam2world_matrix.device),
+ indexing='ij',
+ ))
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
+ uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+ # anchors are indexed as normal (row, col) but uv is indexed as (x, y)
+ x_cam = (uv[:, :, 0].view(N, -1) + anchors[:, 1].unsqueeze(-1)) * (1./resolutions) + (0.5/resolutions)
+ y_cam = (uv[:, :, 1].view(N, -1) + anchors[:, 0].unsqueeze(-1)) * (1./resolutions) + (0.5/resolutions)
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
+
+ x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+ cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
+
+ _opencv2blender = torch.tensor([
+ [1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, 1],
+ ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
+
+ cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
+
+ world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
+
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
+
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
+
+ return ray_origins, ray_dirs
diff --git a/openlrm/models/rendering/utils/renderer.py b/openlrm/models/rendering/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..628fc029877f8e069feb20d3b310ef4692b4a4bc
--- /dev/null
+++ b/openlrm/models/rendering/utils/renderer.py
@@ -0,0 +1,303 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Zexin He in 2023-2024.
+# The modifications are subject to the same license as the original.
+
+
+"""
+The renderer is a module that takes in rays, decides where to sample along each
+ray, and computes pixel colors using the volume rendering equation.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ray_marcher import MipRayMarcher2
+from . import math_utils
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+
+ Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [0, 1, 0],
+ [1, 0, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+
+    Takes plane axes of shape (n_planes, 3, 3) and coordinates of shape (N, M, 3).
+    Returns projections of shape (N*n_planes, M, 2).
+ """
+ N, M, C = coordinates.shape
+ n_planes, _, _ = planes.shape
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+ projections = torch.bmm(coordinates, inv_planes)
+ return projections[..., :2]
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ coordinates = (2/box_warp) * coordinates # add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+ return sampled_features
+
+class ImportanceRenderer(torch.nn.Module):
+ """
+ Modified original version to filter out-of-box samples as TensoRF does.
+
+ Reference:
+ TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
+ """
+ def __init__(self):
+ super().__init__()
+ self.activation_factory = self._build_activation_factory()
+ self.ray_marcher = MipRayMarcher2(self.activation_factory)
+ self.plane_axes = generate_planes()
+
+ def _build_activation_factory(self):
+ def activation_factory(options: dict):
+ if options['clamp_mode'] == 'softplus':
+ return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "Renderer only supports `clamp_mode`=`softplus`!"
+ return activation_factory
+
+ def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
+ planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
+ """
+ Additional filtering is applied to filter out-of-box samples.
+ Modifications made by Zexin He.
+ """
+
+ # context related variables
+ batch_size, num_rays, samples_per_ray, _ = depths.shape
+ device = depths.device
+
+ # define sample points with depths
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+ # filter out-of-box samples
+ mask_inbox = \
+ (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
+ (sample_coordinates <= rendering_options['sampler_bbox_max'])
+ mask_inbox = mask_inbox.all(-1)
+
+ # forward model according to all samples
+ _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
+
+ # set out-of-box samples to zeros(rgb) & -inf(sigma)
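+        # nan_to_num maps -inf to the most negative finite value of DATA_TYPE; dividing by
+        # SAFE_GUARD keeps the density strongly negative (so it is ~zero after the softplus
+        # activation) while leaving numerical headroom for later arithmetic.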
+ SAFE_GUARD = 8
+ DATA_TYPE = _out['sigma'].dtype
+ colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
+ densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
+ colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
+
+ # reshape back
+ colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
+ densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
+
+ return colors_pass, densities_pass
+
+ def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bg_colors=None):
+ # self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+ if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
+ ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
+ is_ray_valid = ray_end > ray_start
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+ else:
+ # Create stratified depth samples
+ depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+
+ # Coarse Pass
+ colors_coarse, densities_coarse = self._forward_pass(
+ depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ if N_importance > 0:
+ _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors)
+
+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+
+ colors_fine, densities_fine = self._forward_pass(
+ depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
+ planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+ all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+ depths_fine, colors_fine, densities_fine)
+
+ # Aggregate
+ rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bg_colors=bg_colors)
+ else:
+ rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors)
+
+ return rgb_final, depth_final, weights.sum(2)
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+ plane_axes = self.plane_axes.to(planes.device)
+ sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+ out = decoder(sampled_features, sample_directions)
+ if options.get('density_noise', 0) > 0:
+ out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+ return out
+
+ def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
+ out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
+ out['sigma'] = self.activation_factory(options)(out['sigma'])
+ return out
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+ all_depths = torch.cat([depths1, depths2], dim = -2)
+ all_colors = torch.cat([colors1, colors2], dim = -2)
+ all_densities = torch.cat([densities1, densities2], dim = -2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities
+
+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = 1/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+ else:
+ depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds-1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+ cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[...,1]-cdf_g[...,0]
+        denom[denom < eps] = 1  # a bin with zero weight will never be sampled anyway,
+                                # so any non-zero denominator is fine (set to 1 here)
+
+        samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
+        return samples
+
+
+def window_partition(x, window_size: int):
+    """
+    Partition the feature map into non-overlapping windows.
+    Args:
+        x: (B, H, W, C)
+        window_size (int): Window size (M)
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
+ # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size: int, H: int, W: int):
+ """
+ Restore each window into a feature map.
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size (M)
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
+ # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
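+# Note: window_reverse is the inverse of window_partition, i.e.
+# window_reverse(window_partition(x, M), M, H, W) recovers x whenever H and W are multiples of M.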
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop)
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop2 = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class WindowCrossAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # [Mh, Mw]
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
+ coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]
+ # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]
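+        # Illustrative example for window_size = (2, 2): a token pair with offsets (dh, dw),
+        # each in {-1, 0, 1}, gets index (dh + 1) * 3 + (dw + 1), i.e. a value in [0, 8] that
+        # addresses the (2*2-1) * (2*2-1) = 9 rows of the bias table above.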
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, kv, mask: Optional[torch.Tensor] = None):
+ """
+ Args:
+            x: query features with shape of (num_windows*B, Mh*Mw, C)
+            kv: key/value features with shape of (num_windows*B, Mh*Mw, C)
+            mask: (0/-100) mask with shape of (num_windows, Mh*Mw, Mh*Mw) or None
+ """
+ # [batch_size*num_windows, Mh*Mw, total_embed_dim]
+ B_, N, C = x.shape
+        # q(): -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
+        # reshape: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
+        # permute: -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
+        q = self.q(x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
+
+ kv = self.kv(kv).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
+ # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw]
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ # mask: [nW, Mh*Mw, Mh*Mw]
+ nW = mask.shape[0] # num_windows
+ # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
+ # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
+ # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
+ # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerCABlock(nn.Module):
+ r""" Swin Transformer Cross Attention Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowCrossAttention(
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, kv, attn_mask):
+ H, W = self.H, self.W
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ kv = self.norm1(kv)
+ kv = kv.view(B, H, W, C)
+
+        # pad the feature map to multiples of the window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ kv = F.pad(kv, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ shifted_kv = torch.roll(kv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+ shifted_kv = kv
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]
+
+ kv_windows = window_partition(shifted_kv, self.window_size) # [nW*B, Mh, Mw, C]
+ kv_windows = kv_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, kv_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C]
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C]
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+            # remove the padding added on the bottom/right
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class CrossAttentionLayer(nn.Module):
+ def __init__(self, dim, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm,):
+ super().__init__()
+ self.dim = dim
+ self.depth = depth
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerCABlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else self.shift_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ def create_mask(self, x, H, W):
+ # calculate attention mask for SW-MSA
+ # Ensure that Hp and Wp are multiples of window_size.
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # use the same [1, Hp, Wp, 1] layout as the feature map so window_partition can be reused
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1]
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
+ # [nW, Mh*Mw, Mh*Mw]
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
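+        # Entries are 0 where query and key fall in the same shifted region and -100 where they
+        # do not; adding -100 to the attention logits suppresses cross-region attention after softmax.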
+ return attn_mask
+
+ def forward(self, x, kv, H, W):
+ attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw]
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ x = blk(x, kv, attn_mask)
+ return x, H, W
+
+
+
+if __name__ == '__main__':
+
+ shape = [8, 3, 32, 64, 64]
+ tensor = torch.zeros(shape)
+ _, _, _, H, W = tensor.shape
+    front_plane = tensor.reshape(-1, 32, 64*64).permute(0, 2, 1).contiguous()
+
+ back_plane = torch.zeros(front_plane.shape)
+
+ model = CrossAttentionLayer(
+ dim=32,
+ depth=2,
+ num_heads=8,
+ window_size=2,
+ )
+
+ output = model(front_plane, back_plane, H, W)
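+    # Expected result (illustrative): output is a tuple (x, H, W) where x has shape
+    # [bs * num_planes, H * W, dim] = [24, 4096, 32] for the dummy input above.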
+
diff --git a/openlrm/models/transformer.py b/openlrm/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..09217957ca643bb07283534ac58139360bda3524
--- /dev/null
+++ b/openlrm/models/transformer.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import partial
+import torch
+import torch.nn as nn
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class TransformerDecoder(nn.Module):
+
+ """
+ Transformer blocks that process the input and optionally use condition and modulation.
+ """
+
+ def __init__(self, block_type: str,
+ num_layers: int, num_heads: int,
+ inner_dim: int, cond_dim: int = None, mod_dim: int = None,
+ eps: float = 1e-6,
+ lora_rank: int = 0):
+ super().__init__()
+ self.block_type = block_type
+ self.layers = nn.ModuleList([
+ self._block_fn(inner_dim, cond_dim, mod_dim, lora_rank=lora_rank)(
+ num_heads=num_heads,
+ eps=eps,
+ )
+ for _ in range(num_layers)
+ ])
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
+
+ @property
+ def block_type(self):
+ return self._block_type
+
+ @block_type.setter
+ def block_type(self, block_type):
+ assert block_type in ['basic', 'cond', 'mod', 'cond_mod'], \
+ f"Unsupported block type: {block_type}"
+ self._block_type = block_type
+
+ def _block_fn(self, inner_dim, cond_dim, mod_dim, lora_rank=0):
+ assert inner_dim is not None, f"inner_dim must always be specified"
+ if self.block_type == 'basic':
+ assert cond_dim is None and mod_dim is None, \
+ f"Condition and modulation are not supported for BasicBlock"
+ from .block import BasicBlock
+ logger.debug(f"Using BasicBlock")
+ return partial(BasicBlock, inner_dim=inner_dim)
+ elif self.block_type == 'cond':
+ assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock"
+ assert mod_dim is None, f"Modulation dimension is not supported for ConditionBlock"
+ from .block import ConditionBlock
+ logger.debug(f"Using ConditionBlock")
+ return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim)
+ elif self.block_type == 'mod':
+ logger.error(f"modulation without condition is not implemented")
+ raise NotImplementedError(f"modulation without condition is not implemented")
+ elif self.block_type == 'cond_mod':
+ assert cond_dim is not None and mod_dim is not None, \
+ f"Condition and modulation dimensions must be specified for ConditionModulationBlock"
+ from .block import ConditionModulationBlock
+ logger.debug(f"Using ConditionModulationBlock")
+ return partial(ConditionModulationBlock, inner_dim=inner_dim, cond_dim=cond_dim, mod_dim=mod_dim, lora_rank=lora_rank)
+ else:
+ raise ValueError(f"Unsupported block type during runtime: {self.block_type}")
+
+ def assert_runtime_integrity(self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor):
+ assert x is not None, f"Input tensor must be specified"
+ if self.block_type == 'basic':
+ assert cond is None and mod is None, \
+ f"Condition and modulation are not supported for BasicBlock"
+ elif self.block_type == 'cond':
+ assert cond is not None and mod is None, \
+ f"Condition must be specified and modulation is not supported for ConditionBlock"
+ elif self.block_type == 'mod':
+ raise NotImplementedError(f"modulation without condition is not implemented")
+ else:
+ assert cond is not None and mod is not None, \
+ f"Condition and modulation must be specified for ConditionModulationBlock"
+
+ def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor):
+ if self.block_type == 'basic':
+ return layer(x)
+ elif self.block_type == 'cond':
+ return layer(x, cond)
+ elif self.block_type == 'mod':
+ return layer(x, mod)
+ else:
+ return layer(x, cond, mod)
+
+ def forward(self, x: torch.Tensor, cond: torch.Tensor = None, mod: torch.Tensor = None):
+ # x: [N, L, D]
+ # cond: [N, L_cond, D_cond] or None
+ # mod: [N, D_mod] or None
+ self.assert_runtime_integrity(x, cond, mod)
+ for layer in self.layers:
+ x = self.forward_layer(layer, x, cond, mod)
+ x = self.norm(x)
+ return x
diff --git a/openlrm/models/utils.py b/openlrm/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..682f75a42e35110f9847fae849240662dc1f08e6
--- /dev/null
+++ b/openlrm/models/utils.py
@@ -0,0 +1,54 @@
+'''
+Utilities for checking, filtering, and loading model state dicts.
+'''
+
+def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
+ """
+    Check that the checkpoint covers the model. Model keys that contain any of the
+    special substrings are allowed to be absent from the checkpoint; any other model
+    key missing from the checkpoint raises a KeyError.
+    ckpt_state_dict: The state dictionary of the checkpoint.
+    model_state_dict: The state dictionary of the model.
+    special_strs: A list of substrings marking keys that may be ignored.
+    """
+    special_strs = special_strs or []
+    filtered_ckpt = {}
+    special_modules = []
+ for key in model_state_dict.keys():
+ if key in ckpt_state_dict:
+ filtered_ckpt[key] = ckpt_state_dict[key]
+ elif any(special_str in key for special_str in special_strs):
+ special_modules.append(key)
+ continue
+ else:
+ raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
+
+def remove_module_prefix(state_dict):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if key.startswith('module.'):
+ new_key = key[len('module.'):]
+ new_state_dict[new_key] = value
+ else:
+ new_state_dict[key] = value
+ return new_state_dict
+
+
+
+# Zero-initializing newly added modules reduces their impact at the beginning of training.
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
+ filtered_ckpt = {}
+ for key in model_state_dict.keys():
+ if key in ckpt_state_dict and any(need_str in key for need_str in need_strs):
+ filtered_ckpt[key] = ckpt_state_dict[key]
+ else:
+ continue
+
+ return filtered_ckpt
diff --git a/openlrm/runners/__init__.py b/openlrm/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b01c77e2357b2bdec2bb1bc9866d54d8d5e58ed
--- /dev/null
+++ b/openlrm/runners/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from openlrm.utils.registry import Registry
+
+REGISTRY_RUNNERS = Registry()
+
+from .train import *
+from .infer import *
diff --git a/openlrm/runners/abstract.py b/openlrm/runners/abstract.py
new file mode 100644
index 0000000000000000000000000000000000000000..76916e805a5cfbf333d2d63e8607811939a5a639
--- /dev/null
+++ b/openlrm/runners/abstract.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+
+class Runner(ABC):
+ """Abstract runner class"""
+
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def run(self):
+ pass
diff --git a/openlrm/runners/infer/__init__.py b/openlrm/runners/infer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf897a0d2a887cdd53f0c67016f7ab197a3788b
--- /dev/null
+++ b/openlrm/runners/infer/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .lrm import LRMInferrer
diff --git a/openlrm/runners/infer/base_inferrer.py b/openlrm/runners/infer/base_inferrer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2725152266dbf52a13bf308a0df0e77e09aa2542
--- /dev/null
+++ b/openlrm/runners/infer/base_inferrer.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from abc import abstractmethod
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+
+from openlrm.runners.abstract import Runner
+
+
+logger = get_logger(__name__)
+
+
+class Inferrer(Runner):
+
+ EXP_TYPE: str = None
+
+ def __init__(self):
+ super().__init__()
+
+ torch._dynamo.config.disable = True
+ self.accelerator = Accelerator()
+
+ self.model : torch.nn.Module = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @abstractmethod
+ def _build_model(self, cfg):
+ pass
+
+ @abstractmethod
+ def infer_single(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def infer(self):
+ pass
+
+ def run(self):
+ self.infer()
diff --git a/openlrm/runners/infer/lrm.py b/openlrm/runners/infer/lrm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f90e19bc245638ee7ef6028f1c0f5e97dca7140
--- /dev/null
+++ b/openlrm/runners/infer/lrm.py
@@ -0,0 +1,403 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import os
+import argparse
+import mcubes
+import trimesh
+import safetensors
+import numpy as np
+from PIL import Image
+from omegaconf import OmegaConf
+from tqdm.auto import tqdm
+from accelerate.logging import get_logger
+from huggingface_hub import hf_hub_download
+
+from .base_inferrer import Inferrer
+from openlrm.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics
+from openlrm.utils.logging import configure_logger
+from openlrm.runners import REGISTRY_RUNNERS
+from openlrm.utils.video import images_to_video
+from openlrm.utils.hf_hub import wrap_model_hub
+
+
+logger = get_logger(__name__)
+
+
+def parse_configs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str)
+ parser.add_argument('--infer', type=str)
+ args, unknown = parser.parse_known_args()
+
+ cfg = OmegaConf.create()
+ cli_cfg = OmegaConf.from_cli(unknown)
+
+ # parse from ENV
+ if os.environ.get('APP_INFER') is not None:
+ args.infer = os.environ.get('APP_INFER')
+ if os.environ.get('APP_MODEL_NAME') is not None:
+ cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')
+ if os.environ.get('APP_PRETRAIN_MODEL_NAME') is not None:
+ cli_cfg.pretrain_model_hf = os.environ.get('APP_PRETRAIN_MODEL_NAME')
+
+ if args.config is not None:
+ cfg_train = OmegaConf.load(args.config)
+ cfg.source_size = cfg_train.dataset.source_image_res
+ cfg.render_size = cfg_train.dataset.render_image.high
+ _relative_path = os.path.join(cfg_train.experiment.parent, cfg_train.experiment.child, os.path.basename(cli_cfg.model_name).split('_')[-1])
+ cfg.video_dump = os.path.join("exps", 'videos', _relative_path)
+ cfg.mesh_dump = os.path.join("exps", 'meshes', _relative_path)
+
+ if args.infer is not None:
+ cfg_infer = OmegaConf.load(args.infer)
+ cfg.merge_with(cfg_infer)
+ if hasattr(cfg, 'experiment') and hasattr(cfg.experiment, 'parent'):
+ cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, cfg.experiment.parent, cfg.experiment.child, 'videos'))
+ cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, cfg.experiment.parent, cfg.experiment.child, 'meshes'))
+ else:
+ cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, 'videos'))
+ cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, 'meshes'))
+
+ cfg.setdefault('double_sided', False)
+ cfg.setdefault('pretrain_model_hf', None)
+ cfg.merge_with(cli_cfg)
+
+ """
+ [required]
+ model_name: str
+ image_input: str
+ export_video: bool
+ export_mesh: bool
+
+ [special]
+ source_size: int
+ render_size: int
+ video_dump: str
+ mesh_dump: str
+
+ [default]
+ render_views: int
+ render_fps: int
+ mesh_size: int
+ mesh_thres: float
+ frame_size: int
+ logger: str
+ """
+
+ cfg.setdefault('inferrer', {})
+ cfg['inferrer'].setdefault('logger', 'INFO')
+
+ # assert not (args.config is not None and args.infer is not None), "Only one of config and infer should be provided"
+ assert cfg.model_name is not None, "model_name is required"
+ if not os.environ.get('APP_ENABLED', None):
+ assert cfg.image_input is not None, "image_input is required"
+ assert cfg.export_video or cfg.export_mesh, \
+ "At least one of export_video or export_mesh should be True"
+ cfg.app_enabled = False
+ else:
+ cfg.app_enabled = True
+
+ return cfg
+
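+# Hypothetical invocation (assuming the OpenLRM-style launcher and illustrative paths):
+#   python -m openlrm.launch infer.lrm --infer ./configs/infer.yaml \
+#       model_name=<model_or_hf_repo> image_input=./assets/sample_front.png \
+#       export_video=true export_mesh=true
+# Extra key=value pairs are parsed by OmegaConf.from_cli and merged into cfg above.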
+
+@REGISTRY_RUNNERS.register('infer.lrm')
+class LRMInferrer(Inferrer):
+
+ EXP_TYPE: str = 'lrm'
+
+ def __init__(self):
+ super().__init__()
+
+ self.cfg = parse_configs()
+ configure_logger(
+ stream_level=self.cfg.inferrer.logger,
+ log_level=self.cfg.inferrer.logger,
+ )
+
+ self.model = self._build_model(self.cfg).to(self.device)
+
+ def _load_checkpoint(self, cfg):
+ ckpt_root = os.path.join(
+ cfg.saver.checkpoint_root,
+ cfg.experiment.parent, cfg.experiment.child,
+ )
+ if not os.path.exists(ckpt_root):
+ raise FileNotFoundError(f"The checkpoint directory '{ckpt_root}' does not exist.")
+ ckpt_dirs = os.listdir(ckpt_root)
+ iter_number = "{:06}".format(cfg.inferrer.iteration)
+ if iter_number not in ckpt_dirs:
+ raise FileNotFoundError(f"Checkpoint for iteration '{iter_number}' not found in '{ckpt_root}'.")
+ inferrer_ckpt_path = os.path.join(ckpt_root, iter_number, 'model.safetensors')
+ logger.info(f"======== Auto-resume from {inferrer_ckpt_path} ========")
+ return inferrer_ckpt_path
+
+ def _build_model(self, cfg):
+ from openlrm.models import model_dict
+ if cfg.inferrer.hugging_face is True: # for huggingface infer
+ hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE])
+ model = hf_model_cls.from_pretrained(cfg.model_name)
+ if cfg.double_sided:
+ pretrain_model_path = hf_hub_download(repo_id=cfg.pretrain_model_hf, filename='model.safetensors')
+                safetensors.torch.load_model( # load the pretrained base model after loading the Tailor3D fine-tuned part
+ model,
+ pretrain_model_path,
+ strict=False
+ )
+ else: # for common infer
+ model = model_dict[self.EXP_TYPE](**cfg['model'])
+ inferrer_ckpt_path = self._load_checkpoint(cfg)
+ if cfg.double_sided:
+ pretrain_model_path = hf_hub_download(repo_id=cfg.pretrain_model_hf, filename='model.safetensors')
+ safetensors.torch.load_model( # load the pretrain model.
+ model,
+ pretrain_model_path,
+ strict=False
+ )
+ safetensors.torch.load_model( # load the finetune model.
+ model,
+ inferrer_ckpt_path,
+ strict=False
+ )
+ else:
+ safetensors.torch.load_model(
+ model,
+ inferrer_ckpt_path,
+ )
+ return model
+
+ @staticmethod
+ def save_images(images, output_path):
+        os.makedirs(output_path, exist_ok=True)
+ for i in range(images.shape[0]):
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ Image.fromarray(frame).save(os.path.join(output_path, f"{str(i)}.png"))
+
+ def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')):
+ # return: (N, D_cam_raw)
+ canonical_camera_extrinsics = torch.tensor([[
+ [1, 0, 0, 0],
+ [0, 0, -1, -dist_to_center],
+ [0, 1, 0, 0],
+ ]], dtype=torch.float32, device=device)
+ canonical_camera_intrinsics = create_intrinsics(
+ f=0.75,
+ c=0.5,
+ device=device,
+ ).unsqueeze(0)
+ source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
+ return source_camera.repeat(batch_size, 1)
+
+ def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')):
+ # return: (N, M, D_cam_render)
+ render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device)
+ render_camera_intrinsics = create_intrinsics(
+ f=0.75,
+ c=0.5,
+ device=device,
+ ).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1)
+ render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics)
+ return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+
+ def infer_planes(self, image: torch.Tensor, source_cam_dist: float, back_image=None):
+ N = image.shape[0]
+ source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device)
+ front_planes = self.model.forward_planes(image, source_camera)
+ if back_image is not None:
+ back_planes = self.model.forward_planes(back_image, source_camera)
+ # XY Plane
+ back_planes[:, 0, :, :, :] = torch.flip(back_planes[:, 0, :, :, :], dims=[-2, -1])
+ # XZ Plane
+ back_planes[:, 1, :, :, :] = torch.flip(back_planes[:, 1, :, :, :], dims=[-1])
+ # YZ Plane
+ back_planes[:, 2, :, :, :] = torch.flip(back_planes[:, 2, :, :, :], dims=[-2])
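+            # The back view is assumed to be a mirrored observation of the same object, so its
+            # tri-plane is flipped above to align with the front-view coordinate frame before fusion.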
+
+ # To fuse the front planes and the back planes
+ bs, num_planes, channels, height, width = front_planes.shape
+ if 'conv_fuse' in self.cfg['model']:
+ planes = torch.cat((front_planes, back_planes), dim=2)
+ planes = planes.reshape(-1, channels*2, height, width)
+ # planes = self.model.front_back_conv(planes).view(bs, num_planes, -1, height, width) # only one layer.
+ # Apply multiple convolutional layers
+ for layer in self.model.front_back_conv:
+ planes = layer(planes)
+
+ planes = planes.view(bs, num_planes, -1, height, width)
+ elif 'swin_ca_fuse' in self.cfg['model']:
+ front_planes = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() # [8, 3, 32, 64, 64] -> [24, 32, 4096] -> [24, 4096, 32]
+ back_planes = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
+ planes = self.model.swin_cross_attention(front_planes, back_planes, height, width)[0].permute(0, 2, 1).reshape(bs, num_planes, channels, height, width)
+ else:
+ planes = front_planes
+
+ assert N == planes.shape[0]
+ return planes
+
+ def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str, image_format=False):
+ N = planes.shape[0]
+ render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device)
+ render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device)
+ render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size
+ render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 1.
+
+ frames = []
+ for i in range(0, render_cameras.shape[1], frame_size):
+ frames.append(
+ self.model.synthesizer(
+ planes=planes,
+ cameras=render_cameras[:, i:i+frame_size],
+ anchors=render_anchors[:, i:i+frame_size],
+ resolutions=render_resolutions[:, i:i+frame_size],
+ bg_colors=render_bg_colors[:, i:i+frame_size],
+ region_size=render_size,
+ )
+ )
+ # merge frames
+ frames = {
+ k: torch.cat([r[k] for r in frames], dim=1)
+ for k in frames[0].keys()
+ }
+ # dump
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
+ for k, v in frames.items():
+ if k == 'images_rgb':
+ if image_format:
+ self.save_images( # save the rendering images directly.
+ v[0],
+ os.path.join(dump_video_path.replace('.mov', ''), 'nvs'),
+ )
+ else:
+ images_to_video(
+ images=v[0],
+ output_path=dump_video_path,
+ fps=render_fps,
+ gradio_codec=self.cfg.app_enabled,
+ )
+
+ def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str):
+ grid_out = self.model.synthesizer.forward_grid(
+ planes=planes,
+ grid_size=mesh_size,
+ )
+
+ vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres)
+ vtx = vtx / (mesh_size - 1) * 2 - 1
+
+ vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
+ vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
+ vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+ mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+
+ # dump
+ os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True)
+ mesh.export(dump_mesh_path)
+
+ def infer_single(self, image_path: str, source_cam_dist: float, export_video: bool, export_mesh: bool, dump_video_path: str, dump_mesh_path: str, image_path_back=None):
+ source_size = self.cfg.inferrer.source_size
+ render_size = self.cfg.inferrer.render_size
+ render_views = self.cfg.inferrer.render_views
+ render_fps = self.cfg.inferrer.render_fps
+ mesh_size = self.cfg.inferrer.mesh_size
+ mesh_thres = self.cfg.inferrer.mesh_thres
+ frame_size = self.cfg.inferrer.frame_size
+ source_cam_dist = self.cfg.inferrer.source_cam_dist if source_cam_dist is None else source_cam_dist
+
+ image_format = self.cfg.inferrer.image_format
+
+ image = self.open_image(image_path, source_size)
+ if image_path_back is None:
+ back_image = self.open_image(image_path.replace('front', 'back'), source_size) if self.cfg.double_sided else None
+ else:
+ back_image = self.open_image(image_path_back, source_size) if self.cfg.double_sided else None
+
+ with torch.no_grad():
+ planes = self.infer_planes(image, source_cam_dist=source_cam_dist, back_image=back_image)
+
+ results = {}
+ if export_video:
+ frames = self.infer_video(planes, frame_size=frame_size, render_size=render_size, render_views=render_views, render_fps=render_fps, dump_video_path=dump_video_path,
+ image_format=image_format)
+ results.update({
+ 'frames': frames,
+ })
+ if export_mesh:
+ mesh = self.infer_mesh(planes, mesh_size=mesh_size, mesh_thres=mesh_thres, dump_mesh_path=dump_mesh_path)
+ results.update({
+ 'mesh': mesh,
+ })
+
+ def data_init(self):
+ image_paths = []
+ if os.path.isfile(self.cfg.image_input):
+ omit_prefix = os.path.dirname(self.cfg.image_input)
+ image_paths.append(self.cfg.image_input)
+ else:
+ omit_prefix = self.cfg.image_input
+ if self.cfg.double_sided: # double sided
+ walk_path = os.path.join(self.cfg.image_input, 'front')
+ else:
+ walk_path = self.cfg.image_input
+ for root, dirs, files in os.walk(walk_path):
+ for file in files:
+ if file.endswith('.png'):
+ image_paths.append(os.path.join(root, file))
+ image_paths.sort()
+ # alloc to each DDP worker
+ image_paths = image_paths[self.accelerator.process_index::self.accelerator.num_processes]
+
+ return image_paths, omit_prefix
+
+ def open_image(self, image_path, source_size):
+ # prepare image: [1, C_img, H_img, W_img], 0-1 scale
+ image = torch.from_numpy(np.array(Image.open(image_path))).to(self.device)
+ image = image.permute(2, 0, 1).unsqueeze(0) / 255.0
+ if image.shape[1] == 4: # RGBA
+ image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
+ image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True)
+ image = torch.clamp(image, 0, 1)
+
+ return image
+
+ def infer(self):
+ image_paths, omit_prefix = self.data_init()
+ for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process):
+
+ # prepare dump paths
+ image_name = os.path.basename(image_path)
+ uid = image_name.split('.')[0]
+ subdir_path = os.path.dirname(image_path).replace(omit_prefix, '')
+ subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path
+ dump_video_path = os.path.join(
+ self.cfg.video_dump,
+ subdir_path,
+ f'{uid}.mov',
+ )
+ dump_mesh_path = os.path.join(
+ self.cfg.mesh_dump,
+ subdir_path,
+ f'{uid}.ply',
+ )
+
+ self.infer_single(
+ image_path,
+ source_cam_dist=None,
+ export_video=self.cfg.export_video,
+ export_mesh=self.cfg.export_mesh,
+ dump_video_path=dump_video_path,
+ dump_mesh_path=dump_mesh_path,
+ )
diff --git a/openlrm/runners/train/__init__.py b/openlrm/runners/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e837a0ba6fd8894e3fec916b72e9543ebc2b3db2
--- /dev/null
+++ b/openlrm/runners/train/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .lrm import LRMTrainer
diff --git a/openlrm/runners/train/base_trainer.py b/openlrm/runners/train/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..420455e308e94a6768c8fef93b74c5c5a38b4305
--- /dev/null
+++ b/openlrm/runners/train/base_trainer.py
@@ -0,0 +1,385 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import time
+import math
+import argparse
+import shutil
+import torch
+import safetensors
+from omegaconf import OmegaConf
+from abc import abstractmethod
+from contextlib import contextmanager
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+
+from openlrm.utils.logging import configure_logger
+from openlrm.utils.compile import configure_dynamo
+from openlrm.runners.abstract import Runner
+
+from collections import OrderedDict
+from huggingface_hub import hf_hub_download
+
+# def my_save_pre_hook(models, weights, output_dir):
+# keep = ["_lora", "synthesizer", "front_back_conv"]
+# for weight_dict in weights:
+# keys_to_keep = [key for key in weight_dict if any(keep_str in key for keep_str in keep)]
+# new_weight_dict = OrderedDict((key, weight_dict[key]) for key in keys_to_keep)
+# weight_dict.clear()
+# weight_dict.update(new_weight_dict)
+
+
+def my_save_pre_hook(models, weights, output_dir):
+ assert len(models) == len(weights), "Models and weights must correspond one-to-one"
+
+ filtered_weights_list = []
+ for model, model_weights in zip(models, weights):
+ filtered_weights = OrderedDict()
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if name in model_weights:
+ filtered_weights[name] = model_weights[name]
+
+ filtered_weights_list.append(filtered_weights)
+
+ weights.clear()
+ weights.extend(filtered_weights_list)
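+# The hook above keeps only parameters with requires_grad=True in the saved state dict, so
+# checkpoints contain just the fine-tuned weights (e.g. LoRA / fusion modules) rather than the full model.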
+
+
+logger = get_logger(__name__)
+
+
+def parse_configs():
+ # Define argparse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, default='./assets/config.yaml')
+ args, unknown = parser.parse_known_args()
+
+ # Load configuration file
+ cfg = OmegaConf.load(args.config)
+
+ # Override with command-line arguments
+ cli_cfg = OmegaConf.from_cli(unknown)
+ cfg = OmegaConf.merge(cfg, cli_cfg)
+
+ return cfg
+
+class Trainer(Runner):
+
+ def __init__(self):
+ super().__init__()
+
+ self.cfg = parse_configs()
+ self.timestamp = time.strftime("%Y%m%d-%H%M%S")
+
+ self.accelerator = Accelerator(
+ mixed_precision=self.cfg.train.mixed_precision,
+ gradient_accumulation_steps=self.cfg.train.accum_steps,
+ log_with=tuple(self.cfg.logger.trackers),
+ project_config=ProjectConfiguration(
+ logging_dir=self.cfg.logger.tracker_root,
+ ),
+ use_seedable_sampler=True,
+ kwargs_handlers=[
+ DistributedDataParallelKwargs(
+ find_unused_parameters=self.cfg.train.find_unused_parameters,
+ ),
+ ],
+ )
+        self.accelerator.register_save_state_pre_hook(my_save_pre_hook) # pre-save hook that filters which weights are checkpointed
+
+ set_seed(self.cfg.experiment.seed, device_specific=True)
+ with self.accelerator.main_process_first():
+ configure_logger(
+ stream_level=self.cfg.logger.stream_level,
+ log_level=self.cfg.logger.log_level,
+ file_path=os.path.join(
+ self.cfg.logger.log_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ f"{self.timestamp}.log",
+ ) if self.accelerator.is_main_process else None,
+ )
+ logger.info(self.accelerator.state, main_process_only=False, in_order=True)
+ configure_dynamo(dict(self.cfg.compile))
+
+ # attributes with defaults
+ self.model : torch.nn.Module = None
+ self.optimizer: torch.optim.Optimizer = None
+ self.scheduler: torch.optim.lr_scheduler.LRScheduler = None
+ self.train_loader: torch.utils.data.DataLoader = None
+ self.val_loader: torch.utils.data.DataLoader = None
+ self.N_max_global_steps: int = None
+ self.N_global_steps_per_epoch: int = None
+ self.global_step: int = 0
+ self.current_epoch: int = 0
+
+ def __enter__(self):
+ self.accelerator.init_trackers(
+ project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}",
+ )
+ self.prepare_everything()
+ self.log_inital_info()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.accelerator.end_training()
+
+ @staticmethod
+ def control(option: str = None, synchronized: bool = False):
+ def decorator(func):
+ def wrapper(self, *args, **kwargs):
+ if option is None or hasattr(self.accelerator, option):
+ accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func
+ result = accelerated_func(self, *args, **kwargs)
+ if synchronized:
+ self.accelerator.wait_for_everyone()
+ return result
+ else:
+ raise AttributeError(f"Accelerator has no attribute {option}")
+ return wrapper
+ return decorator
+
+ @contextmanager
+ def exec_in_order(self):
+ for rank in range(self.accelerator.num_processes):
+ try:
+ if self.accelerator.process_index == rank:
+ yield
+ finally:
+ self.accelerator.wait_for_everyone()
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @property
+ def is_distributed(self) -> bool:
+ return self.accelerator.num_processes > 1
+
+ def prepare_everything(self, is_dist_validation: bool = True):
+ # prepare with accelerator
+ if is_dist_validation:
+ self.model, self.optimizer, self.train_loader, self.val_loader = \
+ self.accelerator.prepare(
+ self.model, self.optimizer, self.train_loader, self.val_loader,
+ )
+ else:
+ self.model, self.optimizer, self.train_loader = \
+ self.accelerator.prepare(
+ self.model, self.optimizer, self.train_loader,
+ )
+ self.accelerator.register_for_checkpointing(self.scheduler)
+ # prepare stats
+ N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps
+ self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps)
+ self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs
+ if self.cfg.train.debug_global_steps is not None:
+ logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}")
+ self.N_max_global_steps = self.cfg.train.debug_global_steps
+ logger.info(f"======== Statistics ========")
+ logger.info(f"** N_max_global_steps: {self.N_max_global_steps}")
+ logger.info(f"** N_total_batch_size: {N_total_batch_size}")
+ logger.info(f"** N_epochs: {self.cfg.train.epochs}")
+ logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}")
+ logger.debug(f"** Prepared loader length: {len(self.train_loader)}")
+ logger.info(f"** Distributed validation: {is_dist_validation}")
+ logger.info(f"============================")
+ logger.info(f"======== Trainable parameters ========")
+ logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+ for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children():
+ logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}")
+ logger.info(f"=====================================")
+ self.accelerator.wait_for_everyone()
+ # load checkpoint or model
+ self.load_ckpt_or_auto_resume_(self.cfg)
+ # register hooks
+ self.register_hooks()
+
+ @abstractmethod
+ def register_hooks(self):
+ pass
+
+ def auto_resume_(self, cfg) -> bool:
+ ckpt_root = os.path.join(
+ cfg.saver.checkpoint_root,
+ cfg.experiment.parent, cfg.experiment.child,
+ )
+ if not os.path.exists(ckpt_root):
+ return False
+ ckpt_dirs = os.listdir(ckpt_root)
+ if len(ckpt_dirs) == 0:
+ return False
+ ckpt_dirs.sort()
+ latest_ckpt = ckpt_dirs[-1]
+ latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt)
+ logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========")
+ self.accelerator.load_state(latest_ckpt_dir, strict=cfg.saver.load_model_func_kwargs.strict)
+ self.global_step = int(latest_ckpt)
+ self.current_epoch = self.global_step // self.N_global_steps_per_epoch
+ return True
+
+ def load_model_(self, cfg):
+ if cfg.saver.load_model.type == 'hugging_face':
+ repo_id, file_name = os.path.dirname(cfg.saver.load_model.url), os.path.basename(cfg.saver.load_model.url)
+ pretrain_model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
+ logger.info(f"======== Loading pretrain model from hugging face {repo_id, file_name} ========")
+ safetensors.torch.load_model(
+ self.accelerator.unwrap_model(self.model),
+ pretrain_model_path,
+ **cfg.saver.load_model_func_kwargs
+ )
+ logger.info(f"======== Pretrain Model loaded ========")
+ return True
+ else:
+ logger.info(f"======== Loading model from {cfg.saver.load_model} ========")
+ safetensors.torch.load_model(
+ self.accelerator.unwrap_model(self.model),
+ cfg.saver.load_model,
+ strict=True,
+ )
+ logger.info(f"======== Model loaded ========")
+ return True
+
+ @control(synchronized=True)
+ def load_ckpt_or_auto_resume_(self, cfg):
+ # auto resume has higher priority, load model from path if auto resume is not available
+ # cfg.saver.auto_resume and cfg.saver.load_model
+ if cfg.saver.auto_resume:
+ successful_resume = self.auto_resume_(cfg)
+ if successful_resume:
+ if cfg.saver.load_model:
+ successful_load = self.load_model_(cfg)
+ if successful_load:
+ return
+ return
+ if cfg.saver.load_model:
+ successful_load = self.load_model_(cfg)
+ if successful_load:
+ return
+ logger.debug(f"======== No checkpoint or model is loaded ========")
+
+ @control('on_main_process', synchronized=True)
+ def save_checkpoint(self):
+ ckpt_dir = os.path.join(
+ self.cfg.saver.checkpoint_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ f"{self.global_step:06d}",
+ )
+ self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True)
+ logger.info(f"======== Saved checkpoint at global step {self.global_step} ========")
+ # manage stratified checkpoints
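+        # Thin out older checkpoints: the further a checkpoint lies behind the latest
+        # one, the coarser the step granularity that is kept (multiples of the
+        # checkpoint period times increasing powers of `checkpoint_keep_level`).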
+ ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir))
+ ckpt_dirs.sort()
+ max_ckpt = int(ckpt_dirs[-1])
+ ckpt_base = int(self.cfg.saver.checkpoint_keep_level)
+ ckpt_period = self.cfg.saver.checkpoint_global_steps
+ logger.debug(f"Checkpoint base: {ckpt_base}")
+ logger.debug(f"Checkpoint period: {ckpt_period}")
+ cur_order = ckpt_base ** math.floor(math.log(max_ckpt // ckpt_period, ckpt_base))
+ cur_idx = 0
+ while cur_order > 0:
+ cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base
+ while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit:
+ if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0:
+ shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx]))
+ logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}")
+ cur_idx += 1
+ cur_order //= ckpt_base
+
+ @property
+ def global_step_in_epoch(self):
+ return self.global_step % self.N_global_steps_per_epoch
+
+ @abstractmethod
+ def _build_model(self):
+ pass
+
+ @abstractmethod
+ def _build_optimizer(self):
+ pass
+
+ @abstractmethod
+ def _build_scheduler(self):
+ pass
+
+ @abstractmethod
+ def _build_dataloader(self):
+ pass
+
+ @abstractmethod
+ def _build_loss_fn(self):
+ pass
+
+ @abstractmethod
+ def train(self):
+ pass
+
+ @abstractmethod
+ def evaluate(self):
+ pass
+
+ @staticmethod
+ def _get_str_progress(epoch: int = None, step: int = None):
+ if epoch is not None:
+ log_type = 'epoch'
+ log_progress = epoch
+ elif step is not None:
+ log_type = 'step'
+ log_progress = step
+ else:
+ raise ValueError('Either epoch or step must be provided')
+ return log_type, log_progress
+
+ @control('on_main_process')
+ def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs):
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ split = f'/{split}' if split else ''
+ for key, value in scalar_kwargs.items():
+ self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress)
+
+ @control('on_main_process')
+    def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = None):
+        log_kwargs = log_kwargs or {}
+        for tracker in self.accelerator.trackers:
+            if hasattr(tracker, 'log_images'):
+                tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {}))
+
+ @control('on_main_process')
+ def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []):
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ assert self.optimizer is not None, 'Optimizer is not initialized'
+ if not attrs:
+ logger.warning('No optimizer attributes are provided, nothing will be logged')
+ if not group_ids:
+ logger.warning('No optimizer group ids are provided, nothing will be logged')
+ for attr in attrs:
+ assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}'
+ for group_id in group_ids:
+ self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress)
+
+ @control('on_main_process')
+    def log_initial_info(self):
+ assert self.model is not None, 'Model is not initialized'
+ assert self.optimizer is not None, 'Optimizer is not initialized'
+ assert self.scheduler is not None, 'Scheduler is not initialized'
+ self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"})
+ self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"})
+ self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"})
+ self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"})
+
+ def run(self):
+ self.train()
diff --git a/openlrm/runners/train/lrm.py b/openlrm/runners/train/lrm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5933d006a7a755cdc45e6ebc101127dd8b5c66
--- /dev/null
+++ b/openlrm/runners/train/lrm.py
@@ -0,0 +1,427 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import math
+from tqdm.auto import tqdm
+import torch
+import torch.nn as nn
+from torchvision.utils import make_grid
+from accelerate.logging import get_logger
+
+from .base_trainer import Trainer
+from openlrm.utils.profiler import DummyProfiler
+from openlrm.runners import REGISTRY_RUNNERS
+
+
+logger = get_logger(__name__)
+
+
+@REGISTRY_RUNNERS.register('train.lrm')
+class LRMTrainer(Trainer):
+ def __init__(self):
+ super().__init__()
+
+ self.model = self._build_model(self.cfg)
+ self.optimizer = self._build_optimizer(self.model, self.cfg)
+ self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
+ self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
+ self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)
+
+ def _build_model(self, cfg):
+ assert cfg.experiment.type == 'lrm', \
+ f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
+ from openlrm.models import ModelLRM
+ model = ModelLRM(**cfg.model)
+ return model
+
+ def _build_optimizer(self, model: nn.Module, cfg):
+ decay_params, no_decay_params = [], []
+
+ # add all bias and LayerNorm params to no_decay_params
+ for name, module in model.named_modules():
+ if isinstance(module, nn.LayerNorm):
+ no_decay_params.extend([p for p in module.parameters()])
+ elif hasattr(module, 'bias') and module.bias is not None:
+ no_decay_params.append(module.bias)
+
+ # add remaining parameters to decay_params
+ _no_decay_ids = set(map(id, no_decay_params))
+ decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]
+
+ # filter out parameters with no grad
+ decay_params = list(filter(lambda p: p.requires_grad, decay_params))
+ no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))
+
+ # monitor this to make sure we don't miss any parameters
+ logger.info("======== Weight Decay Parameters ========")
+ logger.info(f"Total: {len(decay_params)}")
+ logger.info("======== No Weight Decay Parameters ========")
+ logger.info(f"Total: {len(no_decay_params)}")
+
+ # Optimizer
+ opt_groups = [
+ {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
+ {'params': no_decay_params, 'weight_decay': 0.0},
+ ]
+ optimizer = torch.optim.AdamW(
+ opt_groups,
+ lr=cfg.train.optim.lr,
+ betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
+ )
+
+ return optimizer
+
+ def _build_scheduler(self, optimizer, cfg):
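+        # The scheduler is stepped once per optimizer update (i.e. whenever gradients
+        # are synchronized), so its horizon is epochs * ceil(local batches per process / accum_steps).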
+ local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
+ total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
+ effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
+ logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
+ logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
+ if cfg.train.scheduler.type == 'cosine':
+ from openlrm.utils.scheduler import CosineWarmupScheduler
+ scheduler = CosineWarmupScheduler(
+ optimizer=optimizer,
+ warmup_iters=effective_warmup_iters,
+ max_iters=total_global_batches,
+ )
+ else:
+ raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
+ return scheduler
+
+ def _build_dataloader(self, cfg):
+ # dataset class
+ from openlrm.datasets import MixerDataset
+
+ # build dataset
+ train_dataset = MixerDataset(
+ split="train",
+ subsets=cfg.dataset.subsets,
+ sample_side_views=cfg.dataset.sample_side_views,
+ render_image_res_low=cfg.dataset.render_image.low,
+ render_image_res_high=cfg.dataset.render_image.high,
+ render_region_size=cfg.dataset.render_image.region,
+ source_image_res=cfg.dataset.source_image_res,
+ normalize_camera=cfg.dataset.normalize_camera,
+ normed_dist_to_center=cfg.dataset.normed_dist_to_center,
+ )
+ val_dataset = MixerDataset(
+ split="val",
+ subsets=cfg.dataset.subsets,
+ sample_side_views=cfg.dataset.sample_side_views,
+ render_image_res_low=cfg.dataset.render_image.low,
+ render_image_res_high=cfg.dataset.render_image.high,
+ render_region_size=cfg.dataset.render_image.region,
+ source_image_res=cfg.dataset.source_image_res,
+ normalize_camera=cfg.dataset.normalize_camera,
+ normed_dist_to_center=cfg.dataset.normed_dist_to_center,
+ )
+
+ # build data loader
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=cfg.train.batch_size,
+ shuffle=True,
+ drop_last=True,
+ num_workers=cfg.dataset.num_train_workers,
+ pin_memory=cfg.dataset.pin_mem,
+ persistent_workers=True,
+ )
+ val_loader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=cfg.val.batch_size,
+ shuffle=False,
+ drop_last=False,
+ num_workers=cfg.dataset.num_val_workers,
+ pin_memory=cfg.dataset.pin_mem,
+ persistent_workers=False,
+ )
+
+ return train_loader, val_loader
+
+ def _build_loss_fn(self, cfg):
+ from openlrm.losses import PixelLoss, LPIPSLoss, TVLoss
+ pixel_loss_fn = PixelLoss()
+ with self.accelerator.main_process_first():
+ perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)
+ tv_loss_fn = TVLoss()
+ return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn
+
+ def register_hooks(self):
+ pass
+
+ def forward_loss_local_step(self, data):
+
+ source_camera = data['source_camera']
+ render_camera = data['render_camera']
+ source_image = data['source_image']
+ render_image = data['render_image']
+ if 'source_image_back' in data:
+            source_image_back = data['source_image_back']  # optional back-view source image
+ else:
+ source_image_back = None
+ render_anchors = data['render_anchors']
+ render_full_resolutions = data['render_full_resolutions']
+ render_bg_colors = data['render_bg_colors']
+
+ N, M, C, H, W = render_image.shape
+
+ # forward
+ outputs = self.model(
+ image=source_image,
+ source_camera=source_camera,
+ render_cameras=render_camera,
+ render_anchors=render_anchors,
+ render_resolutions=render_full_resolutions,
+ render_bg_colors=render_bg_colors,
+ render_region_size=self.cfg.dataset.render_image.region,
+            image_back=source_image_back,  # back-view conditioning image, may be None
+ )
+
+ # loss calculation
+ loss = 0.
+ loss_pixel = None
+ loss_perceptual = None
+ loss_tv = None
+
+ if self.cfg.train.loss.pixel_weight > 0.:
+ loss_pixel = self.pixel_loss_fn(outputs['images_rgb'], render_image)
+ loss += loss_pixel * self.cfg.train.loss.pixel_weight
+ if self.cfg.train.loss.perceptual_weight > 0.:
+ loss_perceptual = self.perceptual_loss_fn(outputs['images_rgb'], render_image)
+ loss += loss_perceptual * self.cfg.train.loss.perceptual_weight
+ if self.cfg.train.loss.tv_weight > 0.:
+ loss_tv = self.tv_loss_fn(outputs['planes'])
+ loss += loss_tv * self.cfg.train.loss.tv_weight
+
+ return outputs, loss, loss_pixel, loss_perceptual, loss_tv
+
+ def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile):
+ self.model.train()
+
+ local_step_losses = []
+ global_step_losses = []
+
+ logger.debug(f"======== Starting epoch {self.current_epoch} ========")
+ for data in loader:
+
+ logger.debug(f"======== Starting global step {self.global_step} ========")
+ with self.accelerator.accumulate(self.model):
+
+ # forward to loss
+ outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
+
+ # backward
+ self.accelerator.backward(loss)
+ if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ # track local losses
+ local_step_losses.append(torch.stack([
+ _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
+ for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
+ ]))
+
+ # track global step
+ if self.accelerator.sync_gradients:
+ profiler.step()
+ self.scheduler.step()
+ logger.debug(f"======== Scheduler step ========")
+ self.global_step += 1
+ global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
+ loss, loss_pixel, loss_perceptual, loss_tv = global_step_loss.unbind()
+ loss_kwargs = {
+ 'loss': loss.item(),
+ 'loss_pixel': loss_pixel.item(),
+ 'loss_perceptual': loss_perceptual.item(),
+ 'loss_tv': loss_tv.item(),
+ }
+ self.log_scalar_kwargs(
+ step=self.global_step, split='train',
+ **loss_kwargs
+ )
+ self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
+ local_step_losses = []
+ global_step_losses.append(global_step_loss)
+
+ # manage display
+ pbar.update(1)
+ description = {
+ **loss_kwargs,
+ 'lr': self.optimizer.param_groups[0]['lr'],
+ }
+ description = '[TRAIN STEP]' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
+ pbar.set_description(description)
+
+ # periodic actions
+ if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
+ self.save_checkpoint()
+ if self.global_step % self.cfg.val.global_step_period == 0:
+ self.evaluate()
+ self.model.train()
+ if self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0:
+ self.log_image_monitor(
+ step=self.global_step, split='train',
+ renders=outs['images_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ )
+
+ # progress control
+ if self.global_step >= self.N_max_global_steps:
+ self.accelerator.set_trigger()
+ break
+
+ # track epoch
+ self.current_epoch += 1
+ epoch_losses = torch.stack(global_step_losses).mean(dim=0)
+ epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv = epoch_losses.unbind()
+ epoch_loss_dict = {
+ 'loss': epoch_loss.item(),
+ 'loss_pixel': epoch_loss_pixel.item(),
+ 'loss_perceptual': epoch_loss_perceptual.item(),
+ 'loss_tv': epoch_loss_tv.item(),
+ }
+ self.log_scalar_kwargs(
+ epoch=self.current_epoch, split='train',
+ **epoch_loss_dict,
+ )
+ logger.info(
+ f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
+ )
+
+ def train(self):
+
+ starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
+ skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
+ logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")
+
+ with tqdm(
+ range(0, self.N_max_global_steps),
+ initial=self.global_step,
+ disable=(not self.accelerator.is_main_process),
+ ) as pbar:
+
+ profiler = torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+ schedule=torch.profiler.schedule(
+ wait=10, warmup=10, active=100,
+ ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
+ self.cfg.logger.tracker_root,
+ self.cfg.experiment.parent, self.cfg.experiment.child,
+ )),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ ) if self.cfg.logger.enable_profiler else DummyProfiler()
+
+ with profiler:
+
+ self.optimizer.zero_grad()
+ for _ in range(self.current_epoch, self.cfg.train.epochs):
+
+ loader = skipped_loader or self.train_loader
+ skipped_loader = None
+ self.train_epoch(pbar=pbar, loader=loader, profiler=profiler)
+ if self.accelerator.check_trigger():
+ break
+
+ logger.info(f"======== Training finished at global step {self.global_step} ========")
+
+ # final checkpoint and evaluation
+ self.save_checkpoint()
+ self.evaluate()
+
+ @torch.no_grad()
+ @torch.compiler.disable
+ def evaluate(self, epoch: int = None):
+ self.model.eval()
+
+ max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
+ running_losses = []
+ sample_data, sample_outs = None, None
+
+ for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
+
+ if len(running_losses) >= max_val_batches:
+ logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
+ break
+
+ outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
+ sample_data, sample_outs = data, outs
+
+ running_losses.append(torch.stack([
+ _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
+ for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
+ ]))
+
+ total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
+ total_loss, total_loss_pixel, total_loss_perceptual, total_loss_tv = total_losses.unbind()
+ total_loss_dict = {
+ 'loss': total_loss.item(),
+ 'loss_pixel': total_loss_pixel.item(),
+ 'loss_perceptual': total_loss_perceptual.item(),
+ 'loss_tv': total_loss_tv.item(),
+ }
+
+ if epoch is not None:
+ self.log_scalar_kwargs(
+ epoch=epoch, split='val',
+ **total_loss_dict,
+ )
+ logger.info(
+ f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
+ )
+ self.log_image_monitor(
+ epoch=epoch, split='val',
+ renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ )
+ else:
+ self.log_scalar_kwargs(
+ step=self.global_step, split='val',
+ **total_loss_dict,
+ )
+ logger.info(
+ f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
+ ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
+ )
+ self.log_image_monitor(
+ step=self.global_step, split='val',
+ renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
+ )
+
+ @Trainer.control('on_main_process')
+ def log_image_monitor(
+ self, epoch: int = None, step: int = None, split: str = None,
+ renders: torch.Tensor = None, gts: torch.Tensor = None,
+ ):
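+        # `renders` and `gts` are (N, M, C, H, W); the merged grid shows the first
+        # sample only, with its M rendered views on the top row and the matching
+        # ground-truth views on the row below.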
+ M = renders.shape[1]
+ merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
+ renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
+ renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
+ log_type, log_progress = self._get_str_progress(epoch, step)
+ split = f'/{split}' if split else ''
+ self.log_images({
+ f'Images_split{split}/rendered': renders.unsqueeze(0),
+ f'Images_split{split}/gt': gts.unsqueeze(0),
+ f'Images_merged{split}': merged.unsqueeze(0),
+ }, log_progress)
diff --git a/openlrm/utils/__init__.py b/openlrm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6
--- /dev/null
+++ b/openlrm/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Empty
diff --git a/openlrm/utils/compile.py b/openlrm/utils/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..08972a23daf1c046c327ce93fc667b706a3ec65b
--- /dev/null
+++ b/openlrm/utils/compile.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def configure_dynamo(config: dict):
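+    # Copy entries of `config` onto torch._dynamo.config: None values and unknown keys
+    # are skipped, and the whole call is a no-op if torch._dynamo is unavailable.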
+ try:
+ import torch._dynamo
+ logger.debug(f'Configuring torch._dynamo.config with {config}')
+ for k, v in config.items():
+ if v is None:
+ logger.debug(f'Skipping torch._dynamo.config.{k} with None')
+ continue
+ if hasattr(torch._dynamo.config, k):
+ logger.warning(f'Overriding torch._dynamo.config.{k} from {getattr(torch._dynamo.config, k)} to {v}')
+ setattr(torch._dynamo.config, k, v)
+ except ImportError:
+ logger.debug('torch._dynamo not found, skipping')
diff --git a/openlrm/utils/hf_hub.py b/openlrm/utils/hf_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ba0df56983a407d20c2c656a82c1ad15487ca5
--- /dev/null
+++ b/openlrm/utils/hf_hub.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+
+def wrap_model_hub(model_cls: nn.Module):
+ class HfModel(model_cls, PyTorchModelHubMixin):
+ def __init__(self, config: dict):
+ super().__init__(**config)
+ self.config = config
+ return HfModel
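+
+# Illustrative usage (a sketch; the repo id below is a placeholder, not a pinned model):
+#   from openlrm.models import ModelLRM
+#   HfModelLRM = wrap_model_hub(ModelLRM)
+#   model = HfModelLRM.from_pretrained("<hf-user>/<model-repo>")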
diff --git a/openlrm/utils/logging.py b/openlrm/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e2ecd77ff0d1dc9b7fa5cb4efc6edcda8a18d0d
--- /dev/null
+++ b/openlrm/utils/logging.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import logging
+from tqdm.auto import tqdm
+
+
+class TqdmStreamHandler(logging.StreamHandler):
+ def emit(self, record):
+ tqdm.write(self.format(record))
+
+
+def configure_logger(stream_level, log_level, file_path = None):
+ _stream_level = stream_level.upper()
+ _log_level = log_level.upper()
+ _project_level = _log_level
+
+ _formatter = logging.Formatter("[%(asctime)s] %(name)s: [%(levelname)s] %(message)s")
+
+ _stream_handler = TqdmStreamHandler()
+ _stream_handler.setLevel(_stream_level)
+ _stream_handler.setFormatter(_formatter)
+
+ if file_path is not None:
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ _file_handler = logging.FileHandler(file_path)
+ _file_handler.setLevel(_log_level)
+ _file_handler.setFormatter(_formatter)
+
+ _project_logger = logging.getLogger(__name__.split('.')[0])
+ _project_logger.setLevel(_project_level)
+ _project_logger.addHandler(_stream_handler)
+ if file_path is not None:
+ _project_logger.addHandler(_file_handler)
diff --git a/openlrm/utils/preprocess.py b/openlrm/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ea7663c6040f27ff61917f0c85fe72c488f898
--- /dev/null
+++ b/openlrm/utils/preprocess.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import rembg
+import cv2
+import os
+
+def save_image_with_directory_check(save_path, image):
+ directory = os.path.dirname(save_path)
+
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ return cv2.imwrite(save_path, image)
+
+class Preprocessor:
+    """
+    Preprocessing under cv2 conventions.
+    """
+
+ def __init__(self):
+ self.rembg_session = rembg.new_session(
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ )
+
+ def preprocess(self, image_path: str, save_path: str, rmbg: bool = True, recenter: bool = True, size: int = 512, border_ratio: float = 0.2):
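+        # Pipeline: load and rescale so the longer side is twice the target size,
+        # optionally strip the background (yielding a 4-channel image), then either
+        # recenter the foreground on a square canvas with a border margin or plainly
+        # resize to `size` x `size`.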
+ image = self.step_load_to_size(image_path=image_path, size=size*2)
+ if rmbg:
+ image = self.step_rembg(image_in=image)
+ else:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
+ if recenter:
+ image = self.step_recenter(image_in=image, border_ratio=border_ratio, square_size=size)
+ else:
+ image = cv2.resize(
+ src=image,
+ dsize=(size, size),
+ interpolation=cv2.INTER_AREA,
+ )
+ return save_image_with_directory_check(save_path, image)
+
+ def step_rembg(self, image_in: np.ndarray) -> np.ndarray:
+ image_out = rembg.remove(
+ data=image_in,
+ session=self.rembg_session,
+ )
+ return image_out
+
+ def step_recenter(self, image_in: np.ndarray, border_ratio: float, square_size: int) -> np.ndarray:
+ assert image_in.shape[-1] == 4, "Image to recenter must be RGBA"
+ mask = image_in[..., -1] > 0
+ ijs = np.nonzero(mask)
+ # find bbox
+ i_min, i_max = ijs[0].min(), ijs[0].max()
+ j_min, j_max = ijs[1].min(), ijs[1].max()
+ bbox_height, bbox_width = i_max - i_min, j_max - j_min
+ # recenter and resize
+ desired_size = int(square_size * (1 - border_ratio))
+ scale = desired_size / max(bbox_height, bbox_width)
+ desired_height, desired_width = int(bbox_height * scale), int(bbox_width * scale)
+ desired_i_min, desired_j_min = (square_size - desired_height) // 2, (square_size - desired_width) // 2
+ desired_i_max, desired_j_max = desired_i_min + desired_height, desired_j_min + desired_width
+ # create new image
+ image_out = np.zeros((square_size, square_size, 4), dtype=np.uint8)
+ image_out[desired_i_min:desired_i_max, desired_j_min:desired_j_max] = cv2.resize(
+ src=image_in[i_min:i_max, j_min:j_max],
+ dsize=(desired_width, desired_height),
+ interpolation=cv2.INTER_AREA,
+ )
+ return image_out
+
+ def step_load_to_size(self, image_path: str, size: int) -> np.ndarray:
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+ height, width = image.shape[:2]
+ scale = size / max(height, width)
+ height, width = int(height * scale), int(width * scale)
+ image_out = cv2.resize(
+ src=image,
+ dsize=(width, height),
+ interpolation=cv2.INTER_AREA,
+ )
+ return image_out
diff --git a/openlrm/utils/profiler.py b/openlrm/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ba79973308b627d5b20bdd7bb09eac138c93ad
--- /dev/null
+++ b/openlrm/utils/profiler.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch.profiler import profile
+
+
+class DummyProfiler(profile):
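+    # No-op stand-in used when profiling is disabled, so the training loop can still
+    # enter `with profiler:` and call `profiler.step()` unconditionally.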
+ def __init__(self):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ pass
+
+ def step(self):
+ pass
diff --git a/openlrm/utils/proxy.py b/openlrm/utils/proxy.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0c6e53f0a0c412debc866d172926a4c0401bba
--- /dev/null
+++ b/openlrm/utils/proxy.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+NO_PROXY = "OPENLRM_NO_DATA_PROXY" in os.environ
+
+def no_proxy(func):
+ """Decorator to disable proxy but then restore after the function call."""
+ def wrapper(*args, **kwargs):
+ # http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy
+ http_proxy = os.environ.get('http_proxy')
+ https_proxy = os.environ.get('https_proxy')
+ HTTP_PROXY = os.environ.get('HTTP_PROXY')
+ HTTPS_PROXY = os.environ.get('HTTPS_PROXY')
+ all_proxy = os.environ.get('all_proxy')
+ os.environ['http_proxy'] = ''
+ os.environ['https_proxy'] = ''
+ os.environ['HTTP_PROXY'] = ''
+ os.environ['HTTPS_PROXY'] = ''
+ os.environ['all_proxy'] = ''
+ try:
+ return func(*args, **kwargs)
+        finally:
+            # os.environ only accepts strings, so variables that were originally unset
+            # are removed instead of being written back as None
+            for key, value in zip(
+                ('http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'all_proxy'),
+                (http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy),
+            ):
+                if value is None:
+                    os.environ.pop(key, None)
+                else:
+                    os.environ[key] = value
+ if NO_PROXY:
+ return wrapper
+ else:
+ return func
diff --git a/openlrm/utils/registry.py b/openlrm/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..421a735f82899c50884cd5b5a27e71757b2eb813
--- /dev/null
+++ b/openlrm/utils/registry.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class Registry:
+ """Registry class"""
+
+ def __init__(self):
+ self._registry = {}
+
+ def register(self, name):
+ """Register a module"""
+ def decorator(cls):
+ assert name not in self._registry, 'Module {} already registered'.format(name)
+ self._registry[name] = cls
+ return cls
+ return decorator
+
+ def __getitem__(self, name):
+ """Get a module"""
+ return self._registry[name]
+
+ def __contains__(self, name):
+ return name in self._registry
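+
+# Illustrative usage (a sketch; the names below are placeholders):
+#   REGISTRY_EXAMPLE = Registry()
+#
+#   @REGISTRY_EXAMPLE.register('demo')
+#   class Demo:
+#       pass
+#
+#   assert 'demo' in REGISTRY_EXAMPLE and REGISTRY_EXAMPLE['demo'] is Demo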
diff --git a/openlrm/utils/scheduler.py b/openlrm/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc151d816e2787f37f9bea02b0945e06a933c01
--- /dev/null
+++ b/openlrm/utils/scheduler.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from torch.optim.lr_scheduler import LRScheduler
+from accelerate.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class CosineWarmupScheduler(LRScheduler):
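+    # Linear warmup from `initial_lr` to each group's base LR over `warmup_iters`
+    # steps, followed by a cosine decay to zero at `max_iters`.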
+ def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1):
+ self.warmup_iters = warmup_iters
+ self.max_iters = max_iters
+ self.initial_lr = initial_lr
+ super().__init__(optimizer, last_iter)
+
+ def get_lr(self):
+ logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}")
+ if self._step_count <= self.warmup_iters:
+ return [
+ self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters
+ for base_lr in self.base_lrs]
+ else:
+ cos_iter = self._step_count - self.warmup_iters
+ cos_max_iter = self.max_iters - self.warmup_iters
+ cos_theta = cos_iter / cos_max_iter * math.pi
+ cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs]
+ return cos_lr
diff --git a/openlrm/utils/video.py b/openlrm/utils/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb79d5760773be29f55bb313f378bfe13ecb320
--- /dev/null
+++ b/openlrm/utils/video.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2023-2024, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+import imageio
+
+
+def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False):
+ # images: (T, C, H, W)
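+    # Frames are expected as float tensors in [0, 1]; they are scaled to uint8 before encoding.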
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ frames = []
+ for i in range(images.shape[0]):
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+ assert frame.min() >= 0 and frame.max() <= 255, \
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+ frames.append(frame)
+ frames = np.stack(frames)
+ if gradio_codec:
+ imageio.mimwrite(output_path, frames, fps=fps, quality=10)
+ else:
+ imageio.mimwrite(output_path, frames, fps=fps, codec='mpeg4', quality=10)
+ if verbose:
+ print(f"Using gradio codec option {gradio_codec}")
+ print(f"Saved video to {output_path}")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..427ae11e2efff7a65c64ff648771b1121018699d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,16 @@
+lpips
+omegaconf
+transformers
+safetensors
+accelerate
+imageio[ffmpeg]
+PyMCubes
+trimesh
+megfile
+opencv-python
+optimum[onnxruntime-gpu]
+rembg[gpu,cli]
+httpx[socks]
+ninja
+xformers
+git+https://github.com/Baijiong-Lin/LoRA-Torch
\ No newline at end of file