diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d05f48a96da48495a45c4b02288fb020515085b8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,57 @@ +# Ignore Python cache files +**/__pycache__ + +# Ignore compiled Python files +*.pyc + +# Ignore editor-specific files +.vscode/ +.idea/ + +# Ignore operating system files +.DS_Store +Thumbs.db + +# Ignore log files +*.log + +# Ignore temporary and cache files +*.tmp +*.cache + +# Ignore build artifacts +/build/ +/dist/ + +# Ignore virtual environment files +/venv/ +/.venv/ + +# Ignore package manager files +/node_modules/ +/yarn.lock +/package-lock.json + +# Ignore database files +*.db +*.sqlite + +# Ignore secret files +*.secret + +# Ignore compiled binaries +*.exe +*.dll +*.so +*.dylib + +# Ignore backup files +*.bak +*.swp +*.swo +*.~* + + +# Ignore exp result files +dumps/ +exps/ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2a457f756abfa8faa67d901bbecca93a37f7e6 --- /dev/null +++ b/app.py @@ -0,0 +1,394 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from PIL import Image +import numpy as np +import gradio as gr + + +def assert_input_image(input_front_image, input_back_image): + if input_front_image is None: + raise gr.Error("No front image selected or uploaded!") + if input_back_image is None: + raise gr.Error("No back image selected or uploaded!") + +def prepare_working_dir(): + import tempfile + working_dir = tempfile.TemporaryDirectory() + return working_dir + +def init_preprocessor(): + from openlrm.utils.preprocess import Preprocessor + global preprocessor + preprocessor = Preprocessor() + +def preprocess_fn(image_in_front: np.ndarray, image_in_back: np.ndarray, remove_bg: bool, recenter: bool, working_dir): + # save front image first + image_raw_front = os.path.join(working_dir.name, "raw_front.png") + with Image.fromarray(image_in_front) as img: + img.save(image_raw_front) + image_out_front = os.path.join(working_dir.name, "front/rembg_front.png") + + # save back image first + image_raw_back = os.path.join(working_dir.name, "raw_back.png") + with Image.fromarray(image_in_back) as img: + img.save(image_raw_back) + image_out_back = os.path.join(working_dir.name, "back/rembg_back.png") + + # process the front and back image. + success_front = preprocessor.preprocess(image_path=image_raw_front, save_path=image_out_front, rmbg=remove_bg, recenter=recenter) + success_back = preprocessor.preprocess(image_path=image_raw_back, save_path=image_out_back, rmbg=remove_bg, recenter=recenter) + assert success_front and success_back, f"Failed under preprocess_fn!" 
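+ # Both images are now preprocessed (optionally background-removed and recentered) inside the
+ # temporary working directory; their saved paths are returned so core_fn can pass them to the inferrer.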
+ return image_out_front, image_out_back + + +def demo_openlrm(infer_impl): + + def core_fn(image_front: str, image_back: str, source_cam_dist: float, working_dir): + dump_video_path = os.path.join(working_dir.name, "output.mp4") + dump_mesh_path = os.path.join(working_dir.name, "output.ply") + infer_impl( + image_path=image_front, + source_cam_dist=source_cam_dist, + export_video=True, + export_mesh=False, + dump_video_path=dump_video_path, + dump_mesh_path=dump_mesh_path, + image_path_back=image_back, + ) + return dump_video_path + + def example_fn(input_front_image: np.ndarray, input_back_image: np.ndarray): + from gradio.utils import get_cache_folder + working_dir = get_cache_folder() + processed_front_image, processed_back_image = preprocess_fn( + image_in_front=input_front_image, + image_in_back=input_back_image, + remove_bg=True, + recenter=True, + working_dir=working_dir, + ) + video = core_fn( + image_front=processed_front_image, + image_back=processed_back_image, + source_cam_dist=2.0, + working_dir=working_dir, + ) + return processed_front_image, processed_back_image, video + + _TITLE = '''🔥 🔥 🔥 Tailor3D: Customized 3D Assets Editing and Generation with Dual-Side Images''' + + _DESCRIPTION = ''' +
+ + +
+ We propose Tailor3D, a novel pipeline that creates customized 3D assets from editable dual-side images and feed-forward reconstruction methods.
+
+ Here we show the final step of Tailor3D: given the edited front and back views of the object, we can produce the 3D object within several seconds.
+
+ Disclaimer: This demo uses the `Tailor3D-base-1.1` model with a 288x288 rendering resolution for a quick demonstration.
+ '''
+
+ with gr.Blocks(analytics_enabled=False) as demo:
+
+ # HEADERS
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ with gr.Row():
+ gr.Markdown(_DESCRIPTION)
+
+ # DISPLAY
+ with gr.Row():
+ gr.Markdown(
+ """
+ ## 🖼️ Input: The input front and back images.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_input_front_image"):
+ with gr.TabItem('Input Front-view Image'):
+ with gr.Row():
+ input_front_image = gr.Image(label="Input Front Image", image_mode="RGBA", width="auto", sources="upload", type="numpy", elem_id="content_image")
+
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_input_back_image"):
+ with gr.TabItem('Input Back-view Image'):
+ with gr.Row():
+ input_back_image = gr.Image(label="Input Back Image", image_mode="RGBA", width="auto", sources="upload", type="numpy", elem_id="content_image")
+ with gr.Row():
+ gr.Markdown(
+ """
+ ## 🛠️ Preprocess: Remove the background and center the object.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_processed_image"):
+ with gr.TabItem('Processed Front-view Image'):
+ with gr.Row():
+ processed_front_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", width="auto", interactive=False)
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_processed_image"):
+ with gr.TabItem('Processed Back-view Image'):
+ with gr.Row():
+ processed_back_image = gr.Image(label="Processed Image", image_mode="RGBA", type="filepath", elem_id="processed_image", width="auto", interactive=False)
+ with gr.Row():
+ gr.Markdown(
+ """
+ ## 🚀 Output: The rendered video of the 3D object.
+ Note that the actual output is a 3D mesh; for convenience, we showcase it through a video that orbits around the object.
+ """
+ )
+ with gr.Row():
+ with gr.Column(variant='panel', scale=0.2):
+ with gr.Tabs(elem_id="tailor3d_render_video"):
+ with gr.TabItem('Rendered Video'):
+ with gr.Row():
+ output_video = gr.Video(label="Rendered Video", format="mp4", width="auto", autoplay=True)
+
+ # SETTING
+ with gr.Row():
+ with gr.Column(variant='panel', scale=1):
+ with gr.Tabs(elem_id="openlrm_attrs"):
+ with gr.TabItem('Settings'):
+ with gr.Column(variant='panel'):
+ gr.Markdown(
+ """
+ Best Practice:
+ Use centered objects of a reasonable size, and try adjusting the source camera distance.
+ """
+ )
+ checkbox_rembg = gr.Checkbox(True, label='Remove background')
+ checkbox_recenter = gr.Checkbox(True, label='Recenter the object')
+ slider_cam_dist = gr.Slider(1.0, 3.5, value=2.0, step=0.1, label="Source Camera Distance")
+ submit = gr.Button('Generate', elem_id="openlrm_generate", variant='primary')
+
+ # EXAMPLES
+ with gr.Row():
+ gr.Markdown(
+ """
+ ## Examples from the paper.
+ ### A. 3D Style Transfer
+ Here we keep the object identity and only transfer the style.
+ + **Line 1: A pop-mart boy with astronaut, blue, traditional Chinese and grey style.** + """ + ) + with gr.Row(): + examples = [ + ['assets/sample_input/demo/front/boy_astronaut.png', 'assets/sample_input/demo/back/boy_astronaut.png'], + ['assets/sample_input/demo/front/boy_blue.png', 'assets/sample_input/demo/back/boy_blue.png'], + ['assets/sample_input/demo/front/boy_chinese_style.png', 'assets/sample_input/demo/back/boy_chinese_style.png'], + ['assets/sample_input/demo/front/boy_grey_clothes.png', 'assets/sample_input/demo/back/boy_grey_clothes.png'], + ] + + for example in examples: + with gr.Column(scale=1): + gr.Examples( + examples=[example], + inputs=[input_front_image, input_back_image], + outputs=[processed_front_image, processed_back_image, output_video], + fn=example_fn, + cache_examples=bool(os.getenv('SPACE_ID')), + examples_per_page=3, + ) + + # # EXAMPLES + # with gr.Row(): + # gr.Markdown( + # """ + # **Line 2: A LEGO model featuring an astronaut, green and red elements, and a wizard theme.** + # """ + # ) + # with gr.Row(): + # examples = [ + # ['assets/sample_input/demo/front/lego_astronaut.png', 'assets/sample_input/demo/back/lego_astronaut.png'], + # ['assets/sample_input/demo/front/lego_green.png', 'assets/sample_input/demo/back/lego_green.png'], + # ['assets/sample_input/demo/front/lego_red.png', 'assets/sample_input/demo/front/lego_red.png'], + # ['assets/sample_input/demo/front/lego_wizard.png', 'assets/sample_input/demo/back/lego_wizard.png'], + # ] + + # for example in examples: + # with gr.Column(scale=0.3): + # gr.Examples( + # examples=[example], + # inputs=[input_front_image, input_back_image], + # outputs=None, # [processed_image, output_video], + # fn=None, # example_fn, + # cache_examples=bool(os.getenv('SPACE_ID')), + # examples_per_page=3, + # ) + # with gr.Row(): + # gr.Markdown( + # """ + # **Line 3: A marvel boy featuring an Captain America, Ironman and Spiderman, and a Superman theme.** + # """ + # ) + # with gr.Row(): + # examples = [ + # ['assets/sample_input/demo/front/marvel_captain.png', 'assets/sample_input/demo/back/marvel_captain.png'], + # ['assets/sample_input/demo/front/marvel_ironman.png', 'assets/sample_input/demo/front/marvel_ironman.png'], + # ['assets/sample_input/demo/front/marvel_spiderman.png', 'assets/sample_input/demo/back/marvel_spiderman.png'], + # ['assets/sample_input/demo/front/marvel_superman.png', 'assets/sample_input/demo/back/marvel_superman.png'], + # ] + + # for example in examples: + # with gr.Column(scale=0.3): + # gr.Examples( + # examples=[example], + # inputs=[input_front_image, input_back_image], + # outputs=None, # [processed_image, output_video], + # fn=None, # example_fn, + # cache_examples=bool(os.getenv('SPACE_ID')), + # examples_per_page=3, + # ) + # # EXAMPLES + # with gr.Row(): + # gr.Markdown( + # """ + # ### B. 3D Generative Geometry or Pattern Fill + + # Here, we start with a simple object and gradually add various accessories, costumes, or patterns step by step. We only showcase the final effect after multiple rounds of decoration.
+ + # **Line 4: Initial object: sofa, dog, penguin, house.** + # """ + # ) + # with gr.Row(): + # examples = [ + # ['assets/sample_input/demo/front/sofa.png', 'assets/sample_input/demo/back/sofa.png'], + # ['assets/sample_input/demo/front/penguin.png', 'assets/sample_input/demo/back/penguin.png'], + # ['assets/sample_input/demo/front/house.png', 'assets/sample_input/demo/back/house.png'], + # ] + + # for example in examples: + # with gr.Column(scale=0.3): + # gr.Examples( + # examples=[example], + # inputs=[input_front_image, input_back_image], + # outputs=None, # [processed_image, output_video], + # fn=None, # example_fn, + # cache_examples=bool(os.getenv('SPACE_ID')), + # examples_per_page=3, + # ) + + # with gr.Row(): + # gr.Markdown( + # """ + # ### C. 3D Style Fusion + + # We will maintain a consistent front style of the object while continuously changing the back style, blending the two different styles into one object.
+ + # **Line 5: A bird with different back styles.** + # """ + # ) + # with gr.Row(): + # examples = [ + # ['assets/sample_input/demo/front/bird.png', 'assets/sample_input/demo/back/bird.png'], + # ['assets/sample_input/demo/front/bird_brownblue.png', 'assets/sample_input/demo/back/bird_brownblue.png'], + # ['assets/sample_input/demo/front/bird_rainbow.png', 'assets/sample_input/demo/back/bird_rainbow.png'], + # ['assets/sample_input/demo/front/bird_whitered.png', 'assets/sample_input/demo/back/bird_whitered.png'], + # ] + + # for example in examples: + # with gr.Column(scale=0.3): + # gr.Examples( + # examples=[example], + # inputs=[input_front_image, input_back_image], + # outputs=None, # [processed_image, output_video], + # fn=None, # example_fn, + # cache_examples=bool(os.getenv('SPACE_ID')), + # examples_per_page=3, + # ) + + # with gr.Row(): + # gr.Markdown( + # """ + # ### Others + # I vote for kunkun forever, I am really an I-kUN and have heard many of his songs. + # """ + # ) + # with gr.Row(): + # examples = [ + # ['assets/sample_input/demo/front/loopy.png', 'assets/sample_input/demo/back/loopy.png'], + # ['assets/sample_input/demo/front/mario.png', 'assets/sample_input/demo/back/mario.png'], + # ['assets/sample_input/demo/front/armor.png', 'assets/sample_input/demo/back/armor.png'], + # ['assets/sample_input/demo/front/kunkun_law.png', 'assets/sample_input/demo/back/kunkun_law.png'], + # ] + + # for example in examples: + # with gr.Column(scale=0.3): + # gr.Examples( + # examples=[example], + # inputs=[input_front_image, input_back_image], + # outputs=None, # [processed_image, output_video], + # fn=None, # example_fn, + # cache_examples=bool(os.getenv('SPACE_ID')), + # examples_per_page=3, + # ) + + working_dir = gr.State() + submit.click( + fn=assert_input_image, + inputs=[input_front_image, input_back_image], + queue=False, + ).success( + fn=prepare_working_dir, + outputs=[working_dir], + queue=False, + ).success( + fn=preprocess_fn, + inputs=[input_front_image, input_back_image, checkbox_rembg, checkbox_recenter, working_dir], + outputs=[processed_front_image, processed_back_image], + ).success( + fn=core_fn, + inputs=[processed_front_image, processed_back_image, slider_cam_dist, working_dir], + outputs=[output_video], + ) + + demo.queue() + demo.launch() + + +def launch_gradio_app(): + + os.environ.update({ + "APP_ENABLED": "1", + "APP_MODEL_NAME": "alexzyqi/Tailor3D-Base-1.0", + "APP_PRETRAIN_MODEL_NAME": "zxhezexin/openlrm-mix-base-1.1", + "APP_INFER": "./configs/infer-gradio-base.yaml", + "APP_TYPE": "infer.lrm", + "NUMBA_THREADING_LAYER": 'omp', + }) + + from openlrm.runners import REGISTRY_RUNNERS + from openlrm.runners.infer.base_inferrer import Inferrer + InferrerClass : Inferrer = REGISTRY_RUNNERS[os.getenv("APP_TYPE")] + with InferrerClass() as inferrer: + init_preprocessor() + if not bool(os.getenv('SPACE_ID')): + from openlrm.utils.proxy import no_proxy + demo = no_proxy(demo_openlrm) + else: + demo = demo_openlrm + demo(infer_impl=inferrer.infer_single) + + +if __name__ == '__main__': + + launch_gradio_app() diff --git a/assets/sample_input/demo/back/armor.png b/assets/sample_input/demo/back/armor.png new file mode 100644 index 0000000000000000000000000000000000000000..3d4c793f816d58fc0c232513f3be77d7d0508c06 Binary files /dev/null and b/assets/sample_input/demo/back/armor.png differ diff --git a/assets/sample_input/demo/back/bird.png b/assets/sample_input/demo/back/bird.png new file mode 100644 index 
0000000000000000000000000000000000000000..71a53ea433ea7604e947a7a435ba741238d9b174 Binary files /dev/null and b/assets/sample_input/demo/back/bird.png differ diff --git a/assets/sample_input/demo/back/bird_brownblue.png b/assets/sample_input/demo/back/bird_brownblue.png new file mode 100644 index 0000000000000000000000000000000000000000..126188e39e6a665a0f2c0f7c2639a228c75c5f61 Binary files /dev/null and b/assets/sample_input/demo/back/bird_brownblue.png differ diff --git a/assets/sample_input/demo/back/bird_rainbow.png b/assets/sample_input/demo/back/bird_rainbow.png new file mode 100644 index 0000000000000000000000000000000000000000..e97ccee2cee23e38d8afb751f0fc5b938df6d6e1 Binary files /dev/null and b/assets/sample_input/demo/back/bird_rainbow.png differ diff --git a/assets/sample_input/demo/back/bird_whitered.png b/assets/sample_input/demo/back/bird_whitered.png new file mode 100644 index 0000000000000000000000000000000000000000..ce0d5858db2680cd173d7f5fe88e2273c36ac33d Binary files /dev/null and b/assets/sample_input/demo/back/bird_whitered.png differ diff --git a/assets/sample_input/demo/back/boy_astronaut.png b/assets/sample_input/demo/back/boy_astronaut.png new file mode 100644 index 0000000000000000000000000000000000000000..a7e107adff8d0bf16d1a98e899a5b7814edbcbbf Binary files /dev/null and b/assets/sample_input/demo/back/boy_astronaut.png differ diff --git a/assets/sample_input/demo/back/boy_blue.png b/assets/sample_input/demo/back/boy_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..79faeb2b3cdf0480f66bfb92fcd42710027dd607 Binary files /dev/null and b/assets/sample_input/demo/back/boy_blue.png differ diff --git a/assets/sample_input/demo/back/boy_chinese_style.png b/assets/sample_input/demo/back/boy_chinese_style.png new file mode 100644 index 0000000000000000000000000000000000000000..0dbe95df5a9053200cae6f49141b2839adad0fdd Binary files /dev/null and b/assets/sample_input/demo/back/boy_chinese_style.png differ diff --git a/assets/sample_input/demo/back/boy_grey_clothes.png b/assets/sample_input/demo/back/boy_grey_clothes.png new file mode 100644 index 0000000000000000000000000000000000000000..5c0603a2b0724b3b95a088513a141318cdf7a1fc Binary files /dev/null and b/assets/sample_input/demo/back/boy_grey_clothes.png differ diff --git a/assets/sample_input/demo/back/house.png b/assets/sample_input/demo/back/house.png new file mode 100644 index 0000000000000000000000000000000000000000..f729175e7a5c5e776acf6a1138c11d472d8bd617 Binary files /dev/null and b/assets/sample_input/demo/back/house.png differ diff --git a/assets/sample_input/demo/back/kunkun_law.png b/assets/sample_input/demo/back/kunkun_law.png new file mode 100644 index 0000000000000000000000000000000000000000..f796d29e1feb8e5a4c60dc57ddf65553c3ffed35 Binary files /dev/null and b/assets/sample_input/demo/back/kunkun_law.png differ diff --git a/assets/sample_input/demo/back/lego_astronaut.png b/assets/sample_input/demo/back/lego_astronaut.png new file mode 100644 index 0000000000000000000000000000000000000000..cbeff41be23598658c8d9a383c5d8bfdbbf70a9f Binary files /dev/null and b/assets/sample_input/demo/back/lego_astronaut.png differ diff --git a/assets/sample_input/demo/back/lego_green.png b/assets/sample_input/demo/back/lego_green.png new file mode 100644 index 0000000000000000000000000000000000000000..cbabe9b9002f29fed85921d8c560e0934c71f8e4 Binary files /dev/null and b/assets/sample_input/demo/back/lego_green.png differ diff --git a/assets/sample_input/demo/back/lego_red.png 
b/assets/sample_input/demo/back/lego_red.png new file mode 100644 index 0000000000000000000000000000000000000000..c49d948bd40e1f10a8e31cf0e335e7d2fe4c3d14 Binary files /dev/null and b/assets/sample_input/demo/back/lego_red.png differ diff --git a/assets/sample_input/demo/back/lego_wizard.png b/assets/sample_input/demo/back/lego_wizard.png new file mode 100644 index 0000000000000000000000000000000000000000..bd7739fceac0877233dd9f721f1ba36cc260a839 Binary files /dev/null and b/assets/sample_input/demo/back/lego_wizard.png differ diff --git a/assets/sample_input/demo/back/loopy.png b/assets/sample_input/demo/back/loopy.png new file mode 100644 index 0000000000000000000000000000000000000000..f16bf1487e85fdc5b843b494e2706220eeaee150 Binary files /dev/null and b/assets/sample_input/demo/back/loopy.png differ diff --git a/assets/sample_input/demo/back/mario.png b/assets/sample_input/demo/back/mario.png new file mode 100644 index 0000000000000000000000000000000000000000..bea2bda1b0a792a111a9877486c7d72614220412 Binary files /dev/null and b/assets/sample_input/demo/back/mario.png differ diff --git a/assets/sample_input/demo/back/marvel_captain.png b/assets/sample_input/demo/back/marvel_captain.png new file mode 100644 index 0000000000000000000000000000000000000000..ef93ef0b62347bbf8f7b2a3cbc1581e709410c3c Binary files /dev/null and b/assets/sample_input/demo/back/marvel_captain.png differ diff --git a/assets/sample_input/demo/back/marvel_ironman.png b/assets/sample_input/demo/back/marvel_ironman.png new file mode 100644 index 0000000000000000000000000000000000000000..573c584f25b2f6cdea66732d983e7b24f0b63a81 Binary files /dev/null and b/assets/sample_input/demo/back/marvel_ironman.png differ diff --git a/assets/sample_input/demo/back/marvel_spiderman.png b/assets/sample_input/demo/back/marvel_spiderman.png new file mode 100644 index 0000000000000000000000000000000000000000..7f232f0e84d2eafbb51e67794513e46edf3171cf Binary files /dev/null and b/assets/sample_input/demo/back/marvel_spiderman.png differ diff --git a/assets/sample_input/demo/back/marvel_superman.png b/assets/sample_input/demo/back/marvel_superman.png new file mode 100644 index 0000000000000000000000000000000000000000..bb06b4e54d46f7b1bd4a66202317466712108fcd Binary files /dev/null and b/assets/sample_input/demo/back/marvel_superman.png differ diff --git a/assets/sample_input/demo/back/penguin.png b/assets/sample_input/demo/back/penguin.png new file mode 100644 index 0000000000000000000000000000000000000000..277f29c5fa3823baf032fea3b3f5c7fd722973d3 Binary files /dev/null and b/assets/sample_input/demo/back/penguin.png differ diff --git a/assets/sample_input/demo/back/sofa.png b/assets/sample_input/demo/back/sofa.png new file mode 100644 index 0000000000000000000000000000000000000000..b9afb35bdd489f9284986eeca6eaf1c68047259a Binary files /dev/null and b/assets/sample_input/demo/back/sofa.png differ diff --git a/assets/sample_input/demo/front/armor.png b/assets/sample_input/demo/front/armor.png new file mode 100644 index 0000000000000000000000000000000000000000..892f04db2d34c82a4f2dbeac1c9c4665b76f9d45 Binary files /dev/null and b/assets/sample_input/demo/front/armor.png differ diff --git a/assets/sample_input/demo/front/bird.png b/assets/sample_input/demo/front/bird.png new file mode 100644 index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b Binary files /dev/null and b/assets/sample_input/demo/front/bird.png differ diff --git a/assets/sample_input/demo/front/bird_brownblue.png 
b/assets/sample_input/demo/front/bird_brownblue.png new file mode 100644 index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b Binary files /dev/null and b/assets/sample_input/demo/front/bird_brownblue.png differ diff --git a/assets/sample_input/demo/front/bird_rainbow.png b/assets/sample_input/demo/front/bird_rainbow.png new file mode 100644 index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b Binary files /dev/null and b/assets/sample_input/demo/front/bird_rainbow.png differ diff --git a/assets/sample_input/demo/front/bird_whitered.png b/assets/sample_input/demo/front/bird_whitered.png new file mode 100644 index 0000000000000000000000000000000000000000..3bae81d9ffd7be1fbc22bbb231978cc985bb9e2b Binary files /dev/null and b/assets/sample_input/demo/front/bird_whitered.png differ diff --git a/assets/sample_input/demo/front/boy_astronaut.png b/assets/sample_input/demo/front/boy_astronaut.png new file mode 100644 index 0000000000000000000000000000000000000000..85ee3a2cf0f9b7816a2d6dfbf757dba200a87214 Binary files /dev/null and b/assets/sample_input/demo/front/boy_astronaut.png differ diff --git a/assets/sample_input/demo/front/boy_blue.png b/assets/sample_input/demo/front/boy_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..cb853234bcfb8ba07ef6ec3168068478d052a57b Binary files /dev/null and b/assets/sample_input/demo/front/boy_blue.png differ diff --git a/assets/sample_input/demo/front/boy_chinese_style.png b/assets/sample_input/demo/front/boy_chinese_style.png new file mode 100644 index 0000000000000000000000000000000000000000..612176a643ba5d3e5c8605d688de792866ea90ea Binary files /dev/null and b/assets/sample_input/demo/front/boy_chinese_style.png differ diff --git a/assets/sample_input/demo/front/boy_grey_clothes.png b/assets/sample_input/demo/front/boy_grey_clothes.png new file mode 100644 index 0000000000000000000000000000000000000000..2fd30066d890613a2cfffa6faf7a8f0447525e6e Binary files /dev/null and b/assets/sample_input/demo/front/boy_grey_clothes.png differ diff --git a/assets/sample_input/demo/front/house.png b/assets/sample_input/demo/front/house.png new file mode 100644 index 0000000000000000000000000000000000000000..73b3ef6eaa74cf08b2414ff7c73c3113e9f9bf69 Binary files /dev/null and b/assets/sample_input/demo/front/house.png differ diff --git a/assets/sample_input/demo/front/kunkun_law.png b/assets/sample_input/demo/front/kunkun_law.png new file mode 100644 index 0000000000000000000000000000000000000000..232d6b501d7aec2f4e8191ac2ea7dbbbdf5aa7e4 Binary files /dev/null and b/assets/sample_input/demo/front/kunkun_law.png differ diff --git a/assets/sample_input/demo/front/lego_astronaut.png b/assets/sample_input/demo/front/lego_astronaut.png new file mode 100644 index 0000000000000000000000000000000000000000..d97c780ea4d4d7dd9ea69d3fb1bcecd526d9f761 Binary files /dev/null and b/assets/sample_input/demo/front/lego_astronaut.png differ diff --git a/assets/sample_input/demo/front/lego_green.png b/assets/sample_input/demo/front/lego_green.png new file mode 100644 index 0000000000000000000000000000000000000000..a324af0d7a515a501866e4b37f24253bf70c1a48 Binary files /dev/null and b/assets/sample_input/demo/front/lego_green.png differ diff --git a/assets/sample_input/demo/front/lego_red.png b/assets/sample_input/demo/front/lego_red.png new file mode 100644 index 0000000000000000000000000000000000000000..6b32f7d1241c096b41323c4db0a0f655458ee855 Binary files /dev/null and 
b/assets/sample_input/demo/front/lego_red.png differ diff --git a/assets/sample_input/demo/front/lego_wizard.png b/assets/sample_input/demo/front/lego_wizard.png new file mode 100644 index 0000000000000000000000000000000000000000..396fd41b11200f2e85c6388a251beba6e1441507 Binary files /dev/null and b/assets/sample_input/demo/front/lego_wizard.png differ diff --git a/assets/sample_input/demo/front/loopy.png b/assets/sample_input/demo/front/loopy.png new file mode 100644 index 0000000000000000000000000000000000000000..ffdf9844cded0267947e810636c77de6b390b6c4 Binary files /dev/null and b/assets/sample_input/demo/front/loopy.png differ diff --git a/assets/sample_input/demo/front/mario.png b/assets/sample_input/demo/front/mario.png new file mode 100644 index 0000000000000000000000000000000000000000..54780d4f3e741e8cad7ca529a64e750a392a159f Binary files /dev/null and b/assets/sample_input/demo/front/mario.png differ diff --git a/assets/sample_input/demo/front/marvel_captain.png b/assets/sample_input/demo/front/marvel_captain.png new file mode 100644 index 0000000000000000000000000000000000000000..753328918caabdda3d61fe044f4df3a6ef447b7f Binary files /dev/null and b/assets/sample_input/demo/front/marvel_captain.png differ diff --git a/assets/sample_input/demo/front/marvel_ironman.png b/assets/sample_input/demo/front/marvel_ironman.png new file mode 100644 index 0000000000000000000000000000000000000000..fb83948dd80b791fdf5032519b5b95c9d467bf6b Binary files /dev/null and b/assets/sample_input/demo/front/marvel_ironman.png differ diff --git a/assets/sample_input/demo/front/marvel_spiderman.png b/assets/sample_input/demo/front/marvel_spiderman.png new file mode 100644 index 0000000000000000000000000000000000000000..e7aa9ce26669ba148bea42f0886d1a6aac4ed316 Binary files /dev/null and b/assets/sample_input/demo/front/marvel_spiderman.png differ diff --git a/assets/sample_input/demo/front/marvel_superman.png b/assets/sample_input/demo/front/marvel_superman.png new file mode 100644 index 0000000000000000000000000000000000000000..32bbce4fc8c6f9a86203a74afda7cfe3bfea892c Binary files /dev/null and b/assets/sample_input/demo/front/marvel_superman.png differ diff --git a/assets/sample_input/demo/front/penguin.png b/assets/sample_input/demo/front/penguin.png new file mode 100644 index 0000000000000000000000000000000000000000..2e94acf9b089832e61d6fd6447822aee56eaa424 Binary files /dev/null and b/assets/sample_input/demo/front/penguin.png differ diff --git a/assets/sample_input/demo/front/sofa.png b/assets/sample_input/demo/front/sofa.png new file mode 100644 index 0000000000000000000000000000000000000000..0baaa67d61cab614baa43551c98f2b191f33fbe6 Binary files /dev/null and b/assets/sample_input/demo/front/sofa.png differ diff --git a/configs/accelerate-train-4gpus.yaml b/configs/accelerate-train-4gpus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de1bb7c6b6ec93ff568771bc75ca01a3221e43b1 --- /dev/null +++ b/configs/accelerate-train-4gpus.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +# gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +main_process_port: 34567 diff --git a/configs/accelerate-train.yaml b/configs/accelerate-train.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..0bd37d44db1f3b98c6d06bcd655ece22fc8a38b0 --- /dev/null +++ b/configs/accelerate-train.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +main_process_port: 35567 diff --git a/configs/all-base-2sides.yaml b/configs/all-base-2sides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..127898437d98c410548a110f83b6e904ea3c2eb0 --- /dev/null +++ b/configs/all-base-2sides.yaml @@ -0,0 +1,115 @@ + +experiment: + type: lrm + seed: 42 + parent: gobjaverse-2sides-base + child: 0428_conv_e10 + +model: + camera_embed_dim: 1024 + rendering_samples_per_ray: 96 + transformer_dim: 768 + transformer_layers: 12 + transformer_heads: 12 + triplane_low_res: 32 + triplane_high_res: 64 # useless? + triplane_dim: 48 + encoder_type: dinov2 + encoder_model_name: dinov2_vitb14_reg + encoder_feat_dim: 768 + encoder_freeze: false + model_lora_rank: 4 + conv_fuse: True + +dataset: + subsets: + - name: gobjaverse_delete_tb + root_dirs: + ['data/data_gobjaverse_delete_tb'] + meta_path: + train: data/data_gobjaverse_delete_tb/train.json + val: data/data_gobjaverse_delete_tb/val.json + sample_rate: 1.0 + sample_side_views: 3 + source_image_res: 336 + render_image: + low: 96 + high: 288 + region: 96 + normalize_camera: true + normed_dist_to_center: auto + num_train_workers: 4 + num_val_workers: 2 + pin_mem: true + +train: + mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE + find_unused_parameters: false + loss: + pixel_weight: 1.0 + perceptual_weight: 1.0 + tv_weight: 5e-4 + optim: + lr: 4e-4 # most important. 
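+ # the learning rate is annealed by the cosine scheduler below after warmup_real_iters warmup steps;
+ # training fine-tunes from the zxhezexin/openlrm-mix-base-1.1 checkpoint configured under saver.load_model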
+ weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + clip_grad_norm: 1.0 + scheduler: + type: cosine + warmup_real_iters: 3000 + batch_size: 8 # REPLACE THIS (PER GPU) + accum_steps: 1 # REPLACE THIS + epochs: 10 # REPLACE THIS + debug_global_steps: null + +val: + batch_size: 4 + global_step_period: 1000 + debug_batches: null + +saver: + auto_resume: true + load_model: + type: hugging_face + url: zxhezexin/openlrm-mix-base-1.1/model.safetensors + checkpoint_root: ./exps/checkpoints + checkpoint_global_steps: 1000 + checkpoint_keep_level: 5 + load_model_func_kwargs: + strict: False + +logger: + stream_level: WARNING + log_level: INFO + log_root: ./exps/logs + tracker_root: ./exps/trackers + enable_profiler: false + trackers: + - tensorboard + image_monitor: + train_global_steps: 100 + samples_per_log: 4 + +compile: + suppress_errors: true + print_specializations: true + disable: true + +inferrer: + logger: INFO + hugging_face: False + iteration: 3330 + image_format: True + + source_size: 336 + source_cam_dist: 2.0 + render_size: 288 + render_views: 16 + render_fps: 40 + frame_size: 2 + mesh_size: 384 + mesh_thres: 3.0 + +convert: + global_step: \ No newline at end of file diff --git a/configs/all-large-2sides.yaml b/configs/all-large-2sides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c94397086ad243373bf0b64c8c2a8f03675bf0ed --- /dev/null +++ b/configs/all-large-2sides.yaml @@ -0,0 +1,115 @@ + +experiment: + type: lrm + seed: 42 + parent: gobjaverse-2sides-large + child: 0428_conv_e10 + +model: + camera_embed_dim: 1024 + rendering_samples_per_ray: 128 + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 # always 32. + triplane_high_res: 64 # useless? + triplane_dim: 80 + encoder_type: dinov2 + encoder_model_name: dinov2_vitb14_reg + encoder_feat_dim: 768 + encoder_freeze: false + model_lora_rank: 4 + conv_fuse: True + +dataset: + subsets: + - name: gobjaverse_delete_tb + root_dirs: + ['data/data_gobjaverse_delete_tb'] + meta_path: + train: data/data_gobjaverse_delete_tb/train.json + val: data/data_gobjaverse_delete_tb/val.json + sample_rate: 1.0 + sample_side_views: 3 + source_image_res: 336 + render_image: + low: 128 + high: 384 + region: 128 + normalize_camera: true + normed_dist_to_center: auto + num_train_workers: 4 + num_val_workers: 2 + pin_mem: true + +train: + mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE + find_unused_parameters: false + loss: + pixel_weight: 1.0 + perceptual_weight: 1.0 + tv_weight: 5e-4 + optim: + lr: 4e-4 # most important. 
+ weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + clip_grad_norm: 1.0 + scheduler: + type: cosine + warmup_real_iters: 3000 + batch_size: 2 # REPLACE THIS (PER GPU) + accum_steps: 1 # REPLACE THIS + epochs: 10 # REPLACE THIS + debug_global_steps: null + +val: + batch_size: 4 + global_step_period: 1000 + debug_batches: null + +saver: + auto_resume: true + load_model: + type: hugging_face + url: zxhezexin/openlrm-mix-large-1.1/model.safetensors + checkpoint_root: ./exps/checkpoints + checkpoint_global_steps: 1000 + checkpoint_keep_level: 5 + load_model_func_kwargs: + strict: False + +logger: + stream_level: WARNING + log_level: INFO + log_root: ./exps/logs + tracker_root: ./exps/trackers + enable_profiler: false + trackers: + - tensorboard + image_monitor: + train_global_steps: 100 + samples_per_log: 4 + +compile: + suppress_errors: true + print_specializations: true + disable: true + +inferrer: + logger: INFO + hugging_face: False + iteration: 13340 + image_format: True + + source_size: 448 + source_cam_dist: 2.0 + render_size: 384 + render_views: 16 + render_fps: 40 + frame_size: 2 + mesh_size: 1024 + mesh_thres: 1 + +convert: + global_step: \ No newline at end of file diff --git a/configs/all-small-2sides.yaml b/configs/all-small-2sides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a90c858a7213e1a97650ee23dc80b6bea886fff --- /dev/null +++ b/configs/all-small-2sides.yaml @@ -0,0 +1,115 @@ + +experiment: + type: lrm + seed: 42 + parent: gobjaverse-2sides-small + child: 0428_conv + +model: + camera_embed_dim: 1024 + rendering_samples_per_ray: 96 + transformer_dim: 512 + transformer_layers: 12 + transformer_heads: 8 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 32 + encoder_type: dinov2 + encoder_model_name: dinov2_vits14_reg + encoder_feat_dim: 384 + encoder_freeze: false + model_lora_rank: 4 + conv_fuse: True + +dataset: + subsets: + - name: gobjaverse_delete_tb + root_dirs: + ['data/data_gobjaverse_delete_tb'] + meta_path: + train: data/data_gobjaverse_delete_tb/train.json + val: data/data_gobjaverse_delete_tb/val.json + sample_rate: 1.0 + sample_side_views: 3 + source_image_res: 224 + render_image: + low: 64 + high: 192 + region: 64 + normalize_camera: true + normed_dist_to_center: auto + num_train_workers: 4 + num_val_workers: 2 + pin_mem: true + +train: + mixed_precision: fp16 # REPLACE THIS BASED ON GPU TYPE + find_unused_parameters: false + loss: + pixel_weight: 1.0 + perceptual_weight: 1.0 + tv_weight: 5e-4 + optim: + lr: 4e-4 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + clip_grad_norm: 1.0 + scheduler: + type: cosine + warmup_real_iters: 3000 + batch_size: 16 # REPLACE THIS (PER GPU) + accum_steps: 1 # REPLACE THIS + epochs: 10 # REPLACE THIS + debug_global_steps: null + +val: + batch_size: 4 + global_step_period: 1000 + debug_batches: null + +saver: + auto_resume: true + load_model: + type: hugging_face + url: zxhezexin/openlrm-mix-small-1.1/model.safetensors + checkpoint_root: ./exps/checkpoints + checkpoint_global_steps: 1000 + checkpoint_keep_level: 5 + load_model_func_kwargs: + strict: False + +logger: + stream_level: WARNING + log_level: INFO + log_root: ./exps/logs + tracker_root: ./exps/trackers + enable_profiler: false + trackers: + - tensorboard + image_monitor: + train_global_steps: 100 + samples_per_log: 4 + +compile: + suppress_errors: true + print_specializations: true + disable: true + +inferrer: + logger: INFO + hugging_face: False + iteration: 1660 + image_format: True + + source_size: 224 + source_cam_dist: 2.0 + 
render_size: 192 + render_views: 16 + render_fps: 40 + frame_size: 2 + mesh_size: 384 + mesh_thres: 3.0 + +convert: + global_step: \ No newline at end of file diff --git a/configs/infer-gradio-base.yaml b/configs/infer-gradio-base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d5c0aeed8b3511ea66058f978d977b02fc0b121 --- /dev/null +++ b/configs/infer-gradio-base.yaml @@ -0,0 +1,16 @@ +double_sided: True + +model: + conv_fuse: True + +inferrer: + hugging_face: True + source_size: 336 + render_size: 288 + render_views: 100 + render_fps: 25 + frame_size: 2 + mesh_size: 384 + mesh_thres: 3.0 + double_sided: True + image_format: False \ No newline at end of file diff --git a/configs/infer-gradio-large.yaml b/configs/infer-gradio-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fadc8e0bd3cf09e4250033bab74b126f1c1e616 --- /dev/null +++ b/configs/infer-gradio-large.yaml @@ -0,0 +1,16 @@ +double_sided: True + +model: + conv_fuse: True + +inferrer: + hugging_face: True + source_size: 448 + render_size: 384 + render_views: 100 + render_fps: 25 + frame_size: 2 + mesh_size: 1024 + mesh_thres: 3.0 + double_sided: True + image_format: False diff --git a/openlrm/__init__.py b/openlrm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6 --- /dev/null +++ b/openlrm/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/openlrm/datasets/__init__.py b/openlrm/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..323127c7d93f0a57f90cc8649ee2a67b6b630762 --- /dev/null +++ b/openlrm/datasets/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
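+# Re-export MixerDataset so callers can use `from openlrm.datasets import MixerDataset`.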
+ + +from .mixer import MixerDataset diff --git a/openlrm/datasets/back_transform/back_transform.py b/openlrm/datasets/back_transform/back_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..c498177e969e74915cbb4b03fc031054f3ea7efa --- /dev/null +++ b/openlrm/datasets/back_transform/back_transform.py @@ -0,0 +1,24 @@ +import torch +import torchvision.transforms as transforms + +# Add gauusian noise +class AddGaussianNoise: + def __init__(self, mean=0., std=1.): + self.std = std + self.mean = mean + + def __call__(self, img): + return img + torch.randn(img.size()) * self.std + self.mean + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +def transform_back_image(): + return transforms.Compose([ + transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), + transforms.RandomRotation((-5, 5)), + transforms.Pad(padding=10, fill=0, padding_mode='constant'), + transforms.ToTensor(), + AddGaussianNoise(0., 0.1) + ]) diff --git a/openlrm/datasets/base.py b/openlrm/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f327fb147cafbfeaa7455fd8e8f35c9591b2d018 --- /dev/null +++ b/openlrm/datasets/base.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +import json +import numpy as np +import torch +from PIL import Image +from megfile import smart_open, smart_path_join, smart_exists + + +class BaseDataset(torch.utils.data.Dataset, ABC): + def __init__(self, root_dirs: list[str], meta_path: str): + super().__init__() + self.root_dirs = root_dirs + self.uids = self._load_uids(meta_path) + + def __len__(self): + return len(self.uids) + + @abstractmethod + def inner_get_item(self, idx): + pass + + def __getitem__(self, idx): + try: + return self.inner_get_item(idx) + except Exception as e: + print(f"[DEBUG-DATASET] Error when loading {self.uids[idx]}") + # return self.__getitem__(idx+1) + raise e + + @staticmethod + def _load_uids(meta_path: str): + # meta_path is a json file + with open(meta_path, 'r') as f: + uids = json.load(f) + return uids + + @staticmethod + def _load_rgba_image(file_path, bg_color: float = 1.0): + ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled ''' + rgba = np.array(Image.open(smart_open(file_path, 'rb'))) + rgba = torch.from_numpy(rgba).float() / 255.0 + rgba = rgba.permute(2, 0, 1).unsqueeze(0) + rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :]) + rgba[:, :3, ...] * rgba[:, 3:, ...] 
+ (1 - rgba[:, 3:, ...]) + return rgb + + @staticmethod + def _locate_datadir(root_dirs, uid, locator: str): + for root_dir in root_dirs: + datadir = smart_path_join(root_dir, uid, locator) + if smart_exists(datadir): + return root_dir + raise FileNotFoundError(f"Cannot find valid data directory for uid {uid}") diff --git a/openlrm/datasets/cam_utils.py b/openlrm/datasets/cam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70653ae2a7f612714f729c73f45e826109b7e0ff --- /dev/null +++ b/openlrm/datasets/cam_utils.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import torch + +""" +R: (N, 3, 3) +T: (N, 3) +E: (N, 4, 4) +vector: (N, 3) +""" + + +def compose_extrinsic_R_T(R: torch.Tensor, T: torch.Tensor): + """ + Compose the standard form extrinsic matrix from R and T. + Batched I/O. + """ + RT = torch.cat((R, T.unsqueeze(-1)), dim=-1) + return compose_extrinsic_RT(RT) + + +def compose_extrinsic_RT(RT: torch.Tensor): + """ + Compose the standard form extrinsic matrix from RT. + Batched I/O. + """ + return torch.cat([ + RT, + torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1) + ], dim=1) + + +def decompose_extrinsic_R_T(E: torch.Tensor): + """ + Decompose the standard extrinsic matrix into R and T. + Batched I/O. + """ + RT = decompose_extrinsic_RT(E) + return RT[:, :, :3], RT[:, :, 3] + + +def decompose_extrinsic_RT(E: torch.Tensor): + """ + Decompose the standard extrinsic matrix into RT. + Batched I/O. 
+ """ + return E[:, :3, :] + + +def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False): + assert normed_dist_to_center is not None + pivotal_pose = compose_extrinsic_RT(poses[:1]) + dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \ + if normed_dist_to_center == 'auto' else normed_dist_to_center + + # compute camera norm (new version) + canonical_camera_extrinsics = torch.tensor([[ + [1, 0, 0, 0], + [0, 0, -1, -dist_to_center], + [0, 1, 0, 0], + [0, 0, 0, 1], + ]], dtype=torch.float32) + pivotal_pose_inv = torch.inverse(pivotal_pose) + camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv) + + # normalize all views + poses = compose_extrinsic_RT(poses) + poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses) + poses = decompose_extrinsic_RT(poses) + + if ret_transform: + return poses, camera_norm_matrix.squeeze(dim=0) + return poses + + +def get_normalized_camera_intrinsics(intrinsics: torch.Tensor): + """ + intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] + Return batched fx, fy, cx, cy + """ + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1] + cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1] + width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1] + fx, fy = fx / width, fy / height + cx, cy = cx / width, cy / height + return fx, fy, cx, cy + + +def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor): + """ + RT: (N, 3, 4) + intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] + """ + fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics) + return torch.cat([ + RT.reshape(-1, 12), + fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1), + ], dim=-1) + + +def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor): + """ + RT: (N, 3, 4) + intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] + """ + E = compose_extrinsic_RT(RT) + fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics) + I = torch.stack([ + torch.stack([fx, torch.zeros_like(fx), cx], dim=-1), + torch.stack([torch.zeros_like(fy), fy, cy], dim=-1), + torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1), + ], dim=1) + return torch.cat([ + E.reshape(-1, 16), + I.reshape(-1, 9), + ], dim=-1) + + +def center_looking_at_camera_pose( + camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None, + device: torch.device = torch.device('cpu'), + ): + """ + camera_position: (M, 3) + look_at: (3) + up_world: (3) + return: (M, 3, 4) + """ + # by default, looking at the origin and world up is pos-z + if look_at is None: + look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device) + if up_world is None: + up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device) + look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) + up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) + + z_axis = camera_position - look_at + z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True) + x_axis = torch.cross(up_world, z_axis) + x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True) + y_axis = torch.cross(z_axis, x_axis) + y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True) + extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) + return extrinsics + + +def surrounding_views_linspace(n_views: int, radius: float = 2.0, height: float = 0.8, device: torch.device = torch.device('cpu')): + """ + n_views: number of surrounding views 
+ radius: camera dist to center + height: height of the camera + return: (M, 3, 4) + """ + assert n_views > 0 + assert radius > 0 + + theta = torch.linspace(-torch.pi / 2, 3 * torch.pi / 2, n_views, device=device) + projected_radius = math.sqrt(radius ** 2 - height ** 2) + x = torch.cos(theta) * projected_radius + y = torch.sin(theta) * projected_radius + z = torch.full((n_views,), height, device=device) + + camera_positions = torch.stack([x, y, z], dim=1) + extrinsics = center_looking_at_camera_pose(camera_positions, device=device) + + return extrinsics + + +def create_intrinsics( + f: float, + c: float = None, cx: float = None, cy: float = None, + w: float = 1., h: float = 1., + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device('cpu'), + ): + """ + return: (3, 2) + """ + fx = fy = f + if c is not None: + assert cx is None and cy is None, "c and cx/cy cannot be used together" + cx = cy = c + else: + assert cx is not None and cy is not None, "cx/cy must be provided when c is not provided" + fx, fy, cx, cy, w, h = fx/w, fy/h, cx/w, cy/h, 1., 1. + intrinsics = torch.tensor([ + [fx, fy], + [cx, cy], + [w, h], + ], dtype=dtype, device=device) + return intrinsics diff --git a/openlrm/datasets/gobjaverse.py b/openlrm/datasets/gobjaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e1997a6e658c06324bdbb6e19e218af30d4014 --- /dev/null +++ b/openlrm/datasets/gobjaverse.py @@ -0,0 +1,180 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
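+# GobjaverseDataset extends ObjaverseDataset with an extra "opposite view" source image
+# (the back view paired with the sampled front view), which is used for double-sided training.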
+ + +import os +from typing import Union +import random +import numpy as np +import torch +from megfile import smart_path_join, smart_open + +from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse +from ..utils.proxy import no_proxy +from .objaverse import ObjaverseDataset +from .back_transform.back_transform import transform_back_image + +from PIL import Image +from torchvision import transforms + +__all__ = ['GobjaverseDataset'] + +def opposite_view(i): + if 0 <= i <= 24: + return (i + 12) % 24 + elif 27 <= i <= 39: + return ((i - 27) + 6) % 12 + 27 + else: + raise ValueError("Input number must be between 0-24 or 27-39.") + +def get_random_views(rgba_dir, num_views=4): + all_files = [f for f in os.listdir(rgba_dir) if f.endswith('.png')] + view_numbers = [int(os.path.splitext(f)[0]) for f in all_files] + selected_views = random.sample(view_numbers, num_views) + return np.array(selected_views) + +class GobjaverseDataset(ObjaverseDataset): + + def __init__(self, root_dirs: list[str], meta_path: str, + sample_side_views: int, + render_image_res_low: int, render_image_res_high: int, render_region_size: int, + source_image_res: int, normalize_camera: bool, + normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32): + super().__init__( + root_dirs, meta_path, + sample_side_views, + render_image_res_low, + render_image_res_high, + render_region_size, + source_image_res, + normalize_camera, + normed_dist_to_center, + num_all_views, + ) + + self.back_transforms = transform_back_image() + + # This is for gobjaverse and objaverse_mengchen + @staticmethod + def _load_pose_txt(file_path): # load .txt #!!! + with open(file_path, 'r') as file: + lines = file.readlines() + pose_data = np.array([list(map(float, line.split())) for line in lines], dtype=np.float32) + pose = torch.from_numpy(pose_data).reshape(4, 4) # [1. 16] -> [4, 4] -> [3, 4] + opengl2opencv = np.array([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1] + ], dtype=np.float32) + # This is the camera pose in OpenCV format. + pose = np.matmul(pose, opengl2opencv) + return pose[:3, :] # [4, 4] -> [3, 4] + + @staticmethod + def _load_rgba_image_transform(file_path, bg_color: float = 1.0, extra_transforms=None): #!!! + ''' Load and blend RGBA image to RGB with certain background, 0-1 scaled ''' + rgba = np.array(Image.open(smart_open(file_path, 'rb')) ) # (512, 512, 4) + rgba = torch.from_numpy(rgba).float() / 255.0 + rgba = rgba.permute(2, 0, 1).unsqueeze(0) + rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :]) + if extra_transforms is not None: + rgb = extra_transforms( + transforms.ToPILImage()(rgb.squeeze()) + ).unsqueeze(0) + return rgb # [1, 3, 512, 512] + + @no_proxy + def inner_get_item(self, idx): + """ + Loaded contents: + rgbs: [M, 3, H, W] + poses: [M, 3, 4], [R|t] + intrinsics: [3, 2], [[fx, fy], [cx, cy], [weight, height]] + """ + uid = self.uids[idx] + root_dir = self._locate_datadir(self.root_dirs, uid, locator="pose") + + pose_dir = os.path.join(root_dir, uid, 'pose') + rgba_dir = os.path.join(root_dir, uid, 'rgb') + + # only one intrinsics + intrinsics = torch.tensor([[384, 384], [256, 256], [512, 512]], dtype=torch.float) + + # sample views (incl. 
source view and side views) + sample_views = get_random_views(rgba_dir, num_views=self.sample_side_views) + source_image_view_back = opposite_view(sample_views[0]) + sample_views = np.insert(sample_views, 1, source_image_view_back) + + poses, rgbs, bg_colors = [], [], [] + source_image = None + for view in sample_views: + pose_path = smart_path_join(pose_dir, f'{view:03d}.txt') + rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png') + pose = self._load_pose_txt(pose_path) #!!! + bg_color = random.choice([0.0, 0.5, 1.0]) + rgb = self._load_rgba_image(rgba_path, bg_color=bg_color) + poses.append(pose) + rgbs.append(rgb) + bg_colors.append(bg_color) + if source_image is None: + source_image = self._load_rgba_image(rgba_path, bg_color=1.0) + assert source_image is not None, "Really bad luck!" + poses = torch.stack(poses, dim=0) + rgbs = torch.cat(rgbs, dim=0) + + #!!! lora for the backview + source_image_back = self._load_rgba_image_transform(smart_path_join(rgba_dir, f'{sample_views[1]:03d}.png'), bg_color=bg_color) + + if self.normalize_camera: + poses = camera_normalization_objaverse(self.normed_dist_to_center, poses) + + # build source and target camera features + source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0) + render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1)) + + # adjust source image resolution + source_image = torch.nn.functional.interpolate( + source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0) + source_image = torch.clamp(source_image, 0, 1) + + #!!! adjust source_image_back resolution + source_image_back = torch.nn.functional.interpolate( + source_image_back, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0) + source_image_back = torch.clamp(source_image_back, 0, 1) + + # adjust render image resolution and sample intended rendering region + render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1) + render_image = torch.nn.functional.interpolate( + rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True) + render_image = torch.clamp(render_image, 0, 1) + anchors = torch.randint( + 0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2)) + crop_indices = torch.arange(0, self.render_region_size, device=render_image.device) + index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1) + index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size) + batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1) + cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2) + + return { + 'uid': uid, + 'source_camera': source_camera, + 'render_camera': render_camera, + 'source_image': source_image, + 'render_image': cropped_render_image, + 'source_image_back': source_image_back, #!!! 
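+ # the remaining entries record the random crop anchors, the sampled full render resolution
+ # and the per-view background colors used when blending the RGBA renders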
+ 'render_anchors': anchors, + 'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1), + 'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1), + } diff --git a/openlrm/datasets/mixer.py b/openlrm/datasets/mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..8455abf54be7884928fb102dbd610aa4f927b634 --- /dev/null +++ b/openlrm/datasets/mixer.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from functools import partial +import torch + +__all__ = ['MixerDataset'] + + +class MixerDataset(torch.utils.data.Dataset): + + def __init__(self, + split: str, + subsets: list[dict], + **dataset_kwargs, + ): + self.subsets = [ + self._dataset_fn(subset, split)(**dataset_kwargs) + for subset in subsets + ] + self.virtual_lens = [ + math.ceil(subset_config['sample_rate'] * len(subset_obj)) + for subset_config, subset_obj in zip(subsets, self.subsets) + ] + + @staticmethod + def _dataset_fn(subset_config: dict, split: str): + name = subset_config['name'] + + dataset_cls = None + if name == "objaverse": + from .objaverse import ObjaverseDataset + dataset_cls = ObjaverseDataset + elif name == 'gobjaverse_delete_tb': + from .gobjaverse import GobjaverseDataset + dataset_cls = GobjaverseDataset + # elif name == 'mvimgnet': + # from .mvimgnet import MVImgNetDataset + # dataset_cls = MVImgNetDataset + else: + raise NotImplementedError(f"Dataset {name} not implemented") + + return partial( + dataset_cls, + root_dirs=subset_config['root_dirs'], + meta_path=subset_config['meta_path'][split], + ) + + def __len__(self): + return sum(self.virtual_lens) + + def __getitem__(self, idx): + subset_idx = 0 + virtual_idx = idx + while virtual_idx >= self.virtual_lens[subset_idx]: + virtual_idx -= self.virtual_lens[subset_idx] + subset_idx += 1 + real_idx = virtual_idx % len(self.subsets[subset_idx]) + return self.subsets[subset_idx][real_idx] diff --git a/openlrm/datasets/objaverse.py b/openlrm/datasets/objaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..728a5203b3f6b1046aa7ac126c5e86e26fd90e07 --- /dev/null +++ b/openlrm/datasets/objaverse.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
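+# Data layout sketch: the loader below expects, per object UID, a directory of
+# per-view camera poses, RGBA renderings, and one shared intrinsics file (inferred
+# from the path handling in `inner_get_item`; exact filenames depend on how the
+# dataset was rendered):
+#
+#   <root_dir>/<uid>/intrinsics.npy     # [3, 2]: [[fx, fy], [cx, cy], [width, height]]
+#   <root_dir>/<uid>/pose/000.npy ...   # per-view [3, 4] extrinsics [R|t]
+#   <root_dir>/<uid>/rgba/000.png ...   # per-view RGBA images, indexed 000..(num_all_views-1)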
+ + +import os +from typing import Union +import random +import numpy as np +import torch +from megfile import smart_path_join, smart_open + +from .base import BaseDataset +from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse +from ..utils.proxy import no_proxy + +__all__ = ['ObjaverseDataset'] + + +class ObjaverseDataset(BaseDataset): + + def __init__(self, root_dirs: list[str], meta_path: str, + sample_side_views: int, + render_image_res_low: int, render_image_res_high: int, render_region_size: int, + source_image_res: int, normalize_camera: bool, + normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32): + super().__init__(root_dirs, meta_path) + self.sample_side_views = sample_side_views # 3 + self.render_image_res_low = render_image_res_low # 64 + self.render_image_res_high = render_image_res_high # 192 + self.render_region_size = render_region_size # 64 + self.source_image_res = source_image_res # 224s + self.normalize_camera = normalize_camera # True + self.normed_dist_to_center = normed_dist_to_center # 'auto' + self.num_all_views = num_all_views + + @staticmethod + def _load_pose(file_path): + pose = np.load(smart_open(file_path, 'rb')) + pose = torch.from_numpy(pose).float() + return pose + + @no_proxy + def inner_get_item(self, idx): + """ + Loaded contents: + rgbs: [M, 3, H, W] + poses: [M, 3, 4], [R|t] + intrinsics: [3, 2], [[fx, fy], [cx, cy], [weight, height]] + """ + uid = self.uids[idx] + root_dir = self._locate_datadir(self.root_dirs, uid, locator="intrinsics.npy") + + pose_dir = os.path.join(root_dir, uid, 'pose') + rgba_dir = os.path.join(root_dir, uid, 'rgba') + intrinsics_path = os.path.join(root_dir, uid, 'intrinsics.npy') + + # load intrinsics + intrinsics = np.load(smart_open(intrinsics_path, 'rb')) + intrinsics = torch.from_numpy(intrinsics).float() + + # sample views (incl. source view and side views) + sample_views = np.random.choice(range(self.num_all_views), self.sample_side_views + 1, replace=False) + poses, rgbs, bg_colors = [], [], [] + source_image = None + for view in sample_views: + pose_path = smart_path_join(pose_dir, f'{view:03d}.npy') + rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png') + pose = self._load_pose(pose_path) + bg_color = random.choice([0.0, 0.5, 1.0]) + rgb = self._load_rgba_image(rgba_path, bg_color=bg_color) + poses.append(pose) + rgbs.append(rgb) + bg_colors.append(bg_color) + if source_image is None: + source_image = self._load_rgba_image(rgba_path, bg_color=1.0) + assert source_image is not None, "Really bad luck!" 
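+        # At this point `poses` holds (sample_side_views + 1) tensors of shape [3, 4]
+        # and `rgbs` holds the matching [1, 3, H, W] view renders; they are stacked /
+        # concatenated below before camera normalization and feature construction.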
+ poses = torch.stack(poses, dim=0) + rgbs = torch.cat(rgbs, dim=0) + + if self.normalize_camera: + poses = camera_normalization_objaverse(self.normed_dist_to_center, poses) + + # build source and target camera features + source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0) + render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1)) + + # adjust source image resolution + source_image = torch.nn.functional.interpolate( + source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0) + source_image = torch.clamp(source_image, 0, 1) + + # adjust render image resolution and sample intended rendering region + render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1) + render_image = torch.nn.functional.interpolate( + rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True) + render_image = torch.clamp(render_image, 0, 1) + anchors = torch.randint( + 0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2)) + crop_indices = torch.arange(0, self.render_region_size, device=render_image.device) + index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1) + index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size) + batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1) + cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2) + + return { + 'uid': uid, + 'source_camera': source_camera, + 'render_camera': render_camera, + 'source_image': source_image, + 'render_image': cropped_render_image, + 'render_anchors': anchors, + 'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1), + 'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1), + } diff --git a/openlrm/launch.py b/openlrm/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..efb425c93344597a291061a36b8eb94a9c0858ef --- /dev/null +++ b/openlrm/launch.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
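+# Usage sketch (the runner name is a placeholder; valid names are the keys of
+# openlrm.runners.REGISTRY_RUNNERS, and any extra CLI flags are simply tolerated
+# here via parse_known_args()):
+#
+#   python -m openlrm.launch <runner_name> [runner-specific options]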
+ + +import argparse + +from openlrm.runners import REGISTRY_RUNNERS + + +def main(): + + parser = argparse.ArgumentParser(description='OpenLRM launcher') + parser.add_argument('runner', type=str, help='Runner to launch') + args, unknown = parser.parse_known_args() + + if args.runner not in REGISTRY_RUNNERS: + raise ValueError('Runner {} not found'.format(args.runner)) + + RunnerClass = REGISTRY_RUNNERS[args.runner] + with RunnerClass() as runner: + runner.run() + + +if __name__ == '__main__': + main() diff --git a/openlrm/losses/__init__.py b/openlrm/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8da8292b9982cddbfaf84ad3ea74b4bfa9925d --- /dev/null +++ b/openlrm/losses/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .pixelwise import * +from .perceptual import * +from .tvloss import * diff --git a/openlrm/losses/perceptual.py b/openlrm/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..5eead0d1a207e1863598d3400a4a42bd40549114 --- /dev/null +++ b/openlrm/losses/perceptual.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +__all__ = ['LPIPSLoss'] + + +class LPIPSLoss(nn.Module): + """ + Compute LPIPS loss between two images. + """ + + def __init__(self, device, prefech: bool = False): + super().__init__() + self.device = device + self.cached_models = {} + if prefech: + self.prefetch_models() + + def _get_model(self, model_name: str): + if model_name not in self.cached_models: + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning) + import lpips + _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device) + _model = torch.compile(_model) + self.cached_models[model_name] = _model + return self.cached_models[model_name] + + def prefetch_models(self): + _model_names = ['alex', 'vgg'] + for model_name in _model_names: + self._get_model(model_name) + + def forward(self, x, y, is_training: bool = True): + """ + Assume images are 0-1 scaled and channel first. + + Args: + x: [N, M, C, H, W] + y: [N, M, C, H, W] + is_training: whether to use VGG or AlexNet. + + Returns: + Mean-reduced LPIPS loss across batch. 
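+        Example (illustrative; shapes, device and values are made up):
+            >>> lpips = LPIPSLoss(device=torch.device('cpu'))
+            >>> x = torch.rand(2, 4, 3, 64, 64)   # [N, M, C, H, W] in [0, 1]
+            >>> y = torch.rand(2, 4, 3, 64, 64)
+            >>> loss = lpips(x, y, is_training=True)   # scalar tensor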
+ """ + model_name = 'vgg' if is_training else 'alex' + loss_fn = self._get_model(model_name) + N, M, C, H, W = x.shape + x = x.reshape(N*M, C, H, W) + y = y.reshape(N*M, C, H, W) + image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3]) + batch_loss = image_loss.reshape(N, M).mean(dim=1) + all_loss = batch_loss.mean() + return all_loss diff --git a/openlrm/losses/pixelwise.py b/openlrm/losses/pixelwise.py new file mode 100644 index 0000000000000000000000000000000000000000..f936d9960041e49baf2ab1334e9639c219212ec2 --- /dev/null +++ b/openlrm/losses/pixelwise.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +__all__ = ['PixelLoss'] + + +class PixelLoss(nn.Module): + """ + Pixel-wise loss between two images. + """ + + def __init__(self, option: str = 'mse'): + super().__init__() + self.loss_fn = self._build_from_option(option) + + @staticmethod + def _build_from_option(option: str, reduction: str = 'none'): + if option == 'mse': + return nn.MSELoss(reduction=reduction) + elif option == 'l1': + return nn.L1Loss(reduction=reduction) + else: + raise NotImplementedError(f'Unknown pixel loss option: {option}') + + @torch.compile + def forward(self, x, y): + """ + Assume images are channel first. + + Args: + x: [N, M, C, H, W] + y: [N, M, C, H, W] + + Returns: + Mean-reduced pixel loss across batch. + """ + N, M, C, H, W = x.shape + x = x.reshape(N*M, C, H, W) + y = y.reshape(N*M, C, H, W) + image_loss = self.loss_fn(x, y).mean(dim=[1, 2, 3]) + batch_loss = image_loss.reshape(N, M).mean(dim=1) + all_loss = batch_loss.mean() + return all_loss diff --git a/openlrm/losses/tvloss.py b/openlrm/losses/tvloss.py new file mode 100644 index 0000000000000000000000000000000000000000..77a13b69b6f9fcacc38940373bf8159b3cf61459 --- /dev/null +++ b/openlrm/losses/tvloss.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +__all__ = ['TVLoss'] + + +class TVLoss(nn.Module): + """ + Total variance loss. + """ + + def __init__(self): + super().__init__() + + def numel_excluding_first_dim(self, x): + return x.numel() // x.shape[0] + + @torch.compile + def forward(self, x): + """ + Assume batched and channel first with inner sizes. + + Args: + x: [N, M, C, H, W] + + Returns: + Mean-reduced TV loss with element-level scaling. 
+ """ + N, M, C, H, W = x.shape + x = x.reshape(N*M, C, H, W) + diff_i = x[..., 1:, :] - x[..., :-1, :] + diff_j = x[..., :, 1:] - x[..., :, :-1] + div_i = self.numel_excluding_first_dim(diff_i) + div_j = self.numel_excluding_first_dim(diff_j) + tv_i = diff_i.pow(2).sum(dim=[1,2,3]) / div_i + tv_j = diff_j.pow(2).sum(dim=[1,2,3]) / div_j + tv = tv_i + tv_j + batch_tv = tv.reshape(N, M).mean(dim=1) + all_tv = batch_tv.mean() + return all_tv diff --git a/openlrm/models/__init__.py b/openlrm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abde0f5ffc44d03e5bba48e09eb442e73eeb2adc --- /dev/null +++ b/openlrm/models/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .modeling_lrm import ModelLRM + + +model_dict = { + 'lrm': ModelLRM, +} diff --git a/openlrm/models/block.py b/openlrm/models/block.py new file mode 100644 index 0000000000000000000000000000000000000000..a5eb00b873fbf05928c078c9d379c8728628150d --- /dev/null +++ b/openlrm/models/block.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +import loratorch as lora +from .modulate import ModLN + + +class BasicBlock(nn.Module): + """ + Transformer block that is in its simplest form. + Designed for PF-LRM architecture. + """ + # Block contains a self-attention layer and an MLP + def __init__(self, inner_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x): + # x: [N, L, D] + before_sa = self.norm1(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm2(x)) + return x + + +class ConditionBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition. + Designed for SparseLRM architecture. 
+ """ + # Block contains a cross-attention layer, a self-attention layer, and an MLP + def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0., + lora_rank: int = 0): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.cross_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm3 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class ConditionModulationBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + Designed for raw LRM architecture. + """ + # Block contains a cross-attention layer, a self-attention layer, and an MLP + def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0., + lora_rank: int = 0): + super().__init__() + self.norm1 = ModLN(inner_dim, mod_dim, eps) + self.cross_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm2 = ModLN(inner_dim, mod_dim, eps) + self.self_attn = lora.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True, r=lora_rank) + self.norm3 = ModLN(inner_dim, mod_dim, eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond, mod): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + # mod: [N, D_mod] + x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0] + before_sa = self.norm2(x, mod) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm3(x, mod)) + return x diff --git a/openlrm/models/embedder.py b/openlrm/models/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..379721cf5c146cb29aca7695cff8558b0d23673c --- /dev/null +++ b/openlrm/models/embedder.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class CameraEmbedder(nn.Module): + """ + Embed camera features to a high-dimensional vector. + + Reference: + DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27 + """ + def __init__(self, raw_dim: int, embed_dim: int): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(raw_dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim), + ) + + @torch.compile + def forward(self, x): + return self.mlp(x) diff --git a/openlrm/models/encoders/__init__.py b/openlrm/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6 --- /dev/null +++ b/openlrm/models/encoders/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/openlrm/models/encoders/dino_wrapper.py b/openlrm/models/encoders/dino_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..77f475781af22630f3989ff912283fdc2e1b13b4 --- /dev/null +++ b/openlrm/models/encoders/dino_wrapper.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from transformers import ViTImageProcessor, ViTModel +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class DinoWrapper(nn.Module): + """ + Dino v1 wrapper using huggingface transformer implementation. 
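+    Example (the checkpoint id is the standard HF hub name for DINO ViT-B/16;
+    substitute whichever DINO variant the config specifies):
+        >>> encoder = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True)
+        >>> image = torch.rand(1, 3, 224, 224)   # RGB in [0, 1], channel-first
+        >>> tokens = encoder(image)              # [1, 1 + num_patches, hidden_dim]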
+ """ + def __init__(self, model_name: str, freeze: bool = True): + super().__init__() + self.model, self.processor = self._build_dino(model_name) + if freeze: + self._freeze() + + @torch.compile + def forward_model(self, inputs): + return self.model(**inputs, interpolate_pos_encoding=True) + + def forward(self, image): + # image: [N, C, H, W], on cpu + # RGB image with [0,1] scale and properly sized + inputs = self.processor(images=image, return_tensors="pt", do_rescale=False, do_resize=False).to(self.model.device) + # This resampling of positional embedding uses bicubic interpolation + outputs = self.forward_model(inputs) + last_hidden_states = outputs.last_hidden_state + return last_hidden_states + + def _freeze(self): + logger.warning(f"======== Freezing DinoWrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + try: + model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) + processor = ViTImageProcessor.from_pretrained(model_name) + return model, processor + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying ({proxy_error_retries}) in {proxy_error_cooldown} seconds...") + import time + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/openlrm/models/encoders/dinov2/__init__.py b/openlrm/models/encoders/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6 --- /dev/null +++ b/openlrm/models/encoders/dinov2/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/openlrm/models/encoders/dinov2/hub/__init__.py b/openlrm/models/encoders/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/openlrm/models/encoders/dinov2/hub/backbones.py b/openlrm/models/encoders/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd8c4010204da1f1e413db66d24a87e2a39a358 --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/backbones.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + # ********** Modified by Zexin He in 2023-2024 ********** + state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern + if vit_kwargs.get("modulation_dim") is not None: + state_dict = { + k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict, strict=False) + else: + model.load_state_dict(state_dict, strict=True) + # ******************************************************** + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/openlrm/models/encoders/dinov2/hub/classifiers.py b/openlrm/models/encoders/dinov2/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/openlrm/models/encoders/dinov2/hub/depth/__init__.py b/openlrm/models/encoders/dinov2/hub/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py b/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. 
+ """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. 
+ """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. 
+ Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. 
+ padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + 
return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. 
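+    Note: for the default patch_size=16, resize_layers maps the four stages to
+    strides 4, 8, 16 and 32 of the input (4x and 2x transposed-conv upsampling,
+    identity, and a stride-2 conv, respectively).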
+ """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. 
+ """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git 
a/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py b/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8 --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. 
+ """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/openlrm/models/encoders/dinov2/hub/depth/ops.py b/openlrm/models/encoders/dinov2/hub/depth/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/openlrm/models/encoders/dinov2/hub/depthers.py b/openlrm/models/encoders/dinov2/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820 --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = _make_dinov2_linear_depth_head( + 
embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = DepthEncoderDecoder(backbone=backbone, 
decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/openlrm/models/encoders/dinov2/hub/utils.py b/openlrm/models/encoders/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/openlrm/models/encoders/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/openlrm/models/encoders/dinov2/layers/__init__.py b/openlrm/models/encoders/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77967aa6ccfae24c39b8e167c83dd77073fd68fb --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
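The depther factories above attach `CenterPadding` to the backbone through a forward pre-hook (`model.backbone.register_forward_pre_hook(...)`) so that arbitrary input resolutions are padded up to a multiple of the ViT patch size. A minimal sketch of the padding arithmetic, assuming the package path from the diff headers and an arbitrary input resolution:

    import torch
    from openlrm.models.encoders.dinov2.hub.utils import CenterPadding

    x = torch.zeros(1, 3, 480, 500)      # neither 480 nor 500 is a multiple of 14
    y = CenterPadding(multiple=14)(x)    # each spatial dim is padded as evenly as possible per side
    print(y.shape)                       # torch.Size([1, 3, 490, 504]); 490 = 35 * 14, 504 = 36 * 14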
+ +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. +# ****************************************************************************** + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +# ********** Modified by Zexin He in 2023-2024 ********** +# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock +from .block import Block, BlockWithModulation +# ******************************************************** +from .attention import MemEffAttention diff --git a/openlrm/models/encoders/dinov2/layers/attention.py b/openlrm/models/encoders/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb76ef2816164729a58cceb18d0f000cfb18777 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/openlrm/models/encoders/dinov2/layers/block.py 
b/openlrm/models/encoders/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5b50118c1579fd30cda0c2d60b95c85eb04204 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/block.py @@ -0,0 +1,296 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. +# ****************************************************************************** + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = 
drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +# ********** Modified by Zexin He in 2023-2024 ********** +# Override forward with modulation input +class BlockWithModulation(Block): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, mod: Tensor) -> Tensor: + def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x, mod))) + + def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x, mod))) + + if self.training and self.sample_drop_ratio > 0.1: + raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet") + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, mod)) + x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, mod) + x = x + ffn_residual_func(x, mod) + return x +# ******************************************************** + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + 
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + + # ********** Modified by Zexin He in 2023-2024 ********** + warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning) + # ******************************************************** + + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/openlrm/models/encoders/dinov2/layers/dino_head.py b/openlrm/models/encoders/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/openlrm/models/encoders/dinov2/layers/drop_path.py b/openlrm/models/encoders/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/openlrm/models/encoders/dinov2/layers/layer_scale.py b/openlrm/models/encoders/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/openlrm/models/encoders/dinov2/layers/mlp.py b/openlrm/models/encoders/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/openlrm/models/encoders/dinov2/layers/patch_embed.py b/openlrm/models/encoders/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py b/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/openlrm/models/encoders/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/openlrm/models/encoders/dinov2/models/__init__.py b/openlrm/models/encoders/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/openlrm/models/encoders/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . 
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/openlrm/models/encoders/dinov2/models/vision_transformer.py b/openlrm/models/encoders/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c90ac2be1fe294a0db6080cd24155629083d3ec9 --- /dev/null +++ b/openlrm/models/encoders/dinov2/models/vision_transformer.py @@ -0,0 +1,443 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +# ****************************************************************************** +# Code modified by Zexin He in 2023-2024. +# Modifications are marked with clearly visible comments +# licensed under the Apache License, Version 2.0. 
+# ****************************************************************************** + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +# ********** Modified by Zexin He in 2023-2024 ********** +# Avoid using nested tensor for now, deprecating usage of NestedTensorBlock +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, Block, BlockWithModulation +# ******************************************************** + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + # ********** Modified by Zexin He in 2023-2024 ********** + modulation_dim: int = None, + # ******************************************************** + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + + # ********** Modified by Zexin He in 2023-2024 ********** + block_norm_layer = None + if modulation_dim is not None: + from ....modulate import ModLN + block_norm_layer = partial(ModLN, mod_dim=modulation_dim) + else: + block_norm_layer = nn.LayerNorm + block_norm_layer = partial(block_norm_layer, eps=1e-6) + # 
******************************************************** + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + # ********** Modified by Zexin He in 2023-2024 ********** + norm_layer=block_norm_layer, + # ******************************************************** + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + # ********** Modified by Zexin He in 2023-2024 ********** + # hacking unused mask_token for better DDP + # self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + # ******************************************************** + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + 
self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + # ********** Modified by Zexin He in 2023-2024 ********** + raise NotImplementedError("Masking is not supported in hacked DINOv2") + # x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + # ******************************************************** + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + # ********** Modified by Zexin He in 2023-2024 ********** + def forward_features(self, x, masks=None, mod=None): + if isinstance(x, list): + raise DeprecationWarning("forward_features_list is deprecated, use forward_features") + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + if mod is None: + for blk in self.blocks: + x = blk(x) + else: + for blk in self.blocks: + x = blk(x, mod) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + # ******************************************************** + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +# ********** Modified by Zexin He in 2023-2024 ********** +# block class selected from Block and BlockWithModulation + +def _block_cls(**kwargs): + modulation_dim = kwargs.get("modulation_dim", None) + if modulation_dim is None: + block_cls = Block + else: + block_cls = BlockWithModulation + return block_cls + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(_block_cls(**kwargs), attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + +# ******************************************************** diff --git 
a/openlrm/models/encoders/dinov2_wrapper.py b/openlrm/models/encoders/dinov2_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb463ed812c4b92b20d7c49b0219eecd607c607 --- /dev/null +++ b/openlrm/models/encoders/dinov2_wrapper.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class Dinov2Wrapper(nn.Module): + """ + Dino v2 wrapper using original implementation, hacked with modulation. + """ + def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True): + super().__init__() + self.modulation_dim = modulation_dim + self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim) + if freeze: + if modulation_dim is not None: + raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.") + self._freeze() + + def _freeze(self): + logger.warning(f"======== Freezing Dinov2Wrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True): + from importlib import import_module + dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__) + model_fn = getattr(dinov2_hub, model_name) + logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.") + model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained) + return model + + @torch.compile + def forward(self, image: torch.Tensor, mod: torch.Tensor = None): + # image: [N, C, H, W] + # mod: [N, D] or None + # RGB image with [0,1] scale and properly sized + if self.modulation_dim is None: + assert mod is None, "Unexpected modulation input in dinov2 forward." + outs = self.model(image, is_training=True) + else: + assert mod is not None, "Modulation input is required in modulated dinov2 forward." + outs = self.model(image, mod=mod, is_training=True) + ret = torch.cat([ + outs["x_norm_clstoken"].unsqueeze(dim=1), + outs["x_norm_patchtokens"], + ], dim=1) + return ret diff --git a/openlrm/models/modeling_lrm.py b/openlrm/models/modeling_lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..f865937cb7cf51f692297b726075af36bbcaab3e --- /dev/null +++ b/openlrm/models/modeling_lrm.py @@ -0,0 +1,245 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
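+#
+# Note on the dual-side design implemented in this file: forward_planes() lifts a single view to a
+# triplane; when a back view is provided, forward() runs it through the same branch, mirrors the
+# resulting back triplane so that it aligns with the front one (the XY plane is flipped along both
+# spatial axes, XZ horizontally, YZ vertically), and fuses the two triplanes either with a small
+# convolutional stack (conv_fuse) or with windowed cross-attention (swin_ca_fuse).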
+ + +import torch +import torch.nn as nn +from accelerate.logging import get_logger + +from .embedder import CameraEmbedder +from .transformer import TransformerDecoder +from .rendering.synthesizer import TriplaneSynthesizer +from .utils import zero_module +import loratorch as lora +from .swin_transformer import CrossAttentionLayer + +logger = get_logger(__name__) + + +class ModelLRM(nn.Module): + """ + Full model of the basic single-view large reconstruction model. + """ + def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int, + transformer_dim: int, transformer_layers: int, transformer_heads: int, + triplane_low_res: int, triplane_high_res: int, triplane_dim: int, + encoder_freeze: bool = True, encoder_type: str = 'dino', + encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768, + model_lora_rank: int = 0, conv_fuse=False, + swin_ca_fuse=False, ca_dim=32, ca_depth=2, ca_num_heads=8, ca_window_size=2): + super().__init__() + + # attributes + self.encoder_feat_dim = encoder_feat_dim + self.camera_embed_dim = camera_embed_dim + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + self.conv_fuse = conv_fuse + self.swin_ca_fuse = swin_ca_fuse + + # modules + self.encoder = self._encoder_fn(encoder_type)( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + self.camera_embedder = CameraEmbedder( + raw_dim=12+4, embed_dim=camera_embed_dim, + ) + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5) + if model_lora_rank > 0: + self.transformer = TransformerDecoder( + block_type='cond_mod', + num_layers=transformer_layers, num_heads=transformer_heads, + inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim, + lora_rank=model_lora_rank + ) + lora.mark_only_lora_as_trainable(self.transformer) + else: + self.transformer = TransformerDecoder( + block_type='cond_mod', + num_layers=transformer_layers, num_heads=transformer_heads, + inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim, + ) + self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=2, stride=2, padding=0) + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray, + ) + + if model_lora_rank > 0: + if self.conv_fuse: + # self.front_back_conv = nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1) + # zero_module(self.front_back_conv) + self.front_back_conv = nn.ModuleList([ + nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1), + nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]), # Using Layer Normalization + nn.GELU(), # Using GELU activation + nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1), + nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]), # Using Layer Normalization + nn.GELU(), # Using GELU activation + nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1) + ]) + self.freeze_modules(encoder=True, camera_embedder=True, + pos_embed=False, transformer=False, upsampler=False, + synthesizer=False) + elif self.swin_ca_fuse: + self.swin_cross_attention = CrossAttentionLayer(dim=ca_dim, depth=ca_depth, 
num_heads=ca_num_heads, window_size=ca_window_size) + self.freeze_modules(encoder=True, camera_embedder=True, + pos_embed=False, transformer=False, upsampler=False, + synthesizer=False) + else: + raise ValueError("You need to specify a method for fusing the front and the back.") + + + def freeze_modules(self, encoder=False, camera_embedder=False, + pos_embed=False, transformer=False, upsampler=False, + synthesizer=False): + """ + Freeze specified modules + """ + if encoder: + for param in self.encoder.parameters(): + param.requires_grad = False + if camera_embedder: + for param in self.camera_embedder.parameters(): + param.requires_grad = False + if pos_embed: + for param in self.pos_embed.parameters(): + param.requires_grad = False + if transformer: + for param in self.transformer.parameters(): + param.requires_grad = False + if upsampler: + for param in self.upsampler.parameters(): + param.requires_grad = False + if synthesizer: + for param in self.synthesizer.parameters(): + param.requires_grad = False + + @staticmethod + def _encoder_fn(encoder_type: str): + encoder_type = encoder_type.lower() + assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type" + if encoder_type == 'dino': + from .encoders.dino_wrapper import DinoWrapper + logger.info("Using DINO as the encoder") + return DinoWrapper + elif encoder_type == 'dinov2': + from .encoders.dinov2_wrapper import Dinov2Wrapper + logger.info("Using DINOv2 as the encoder") + return Dinov2Wrapper + + def forward_transformer(self, image_feats, camera_embeddings): + assert image_feats.shape[0] == camera_embeddings.shape[0], \ + "Batch size mismatch for image_feats and camera_embeddings!" + N = image_feats.shape[0] + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + x = self.transformer( + x, + cond=image_feats, + mod=camera_embeddings, + ) + return x + + def reshape_upsample(self, tokens): + N = tokens.shape[0] + H = W = self.triplane_low_res + x = tokens.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.upsampler(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + return x + + @torch.compile + def forward_planes(self, image, camera): + # image: [N, C_img, H_img, W_img] + # camera: [N, D_cam_raw] + N = image.shape[0] + + # encode image + image_feats = self.encoder(image) + assert image_feats.shape[-1] == self.encoder_feat_dim, \ + f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}" + + # embed camera + camera_embeddings = self.camera_embedder(camera) + assert camera_embeddings.shape[-1] == self.camera_embed_dim, \ + f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}" + + # transformer generating planes + tokens = self.forward_transformer(image_feats, camera_embeddings) + planes = self.reshape_upsample(tokens) + assert planes.shape[0] == N, "Batch size mismatch for planes" + assert planes.shape[1] == 3, "Planes should have 3 channels" + + return planes + + def forward(self, image, source_camera, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size: int, + image_back=None,): + # image: [N, C_img, H_img, W_img] + # source_camera: [N, D_cam_raw] + # render_cameras: [N, M, D_cam_render] + # render_anchors: [N, M, 2] + # render_resolutions: [N, M, 1] + # render_bg_colors: [N, M, 1] + # render_region_size: int + assert image.shape[0] == 
source_camera.shape[0], "Batch size mismatch for image and source_camera" + assert image.shape[0] == render_cameras.shape[0], "Batch size mismatch for image and render_cameras" + assert image.shape[0] == render_anchors.shape[0], "Batch size mismatch for image and render_anchors" + assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors" + N, M = render_cameras.shape[:2] + + if image_back is not None: + front_planes = self.forward_planes(image, source_camera) + back_planes = self.forward_planes(image_back, source_camera) + + # XY Plane + back_planes[:, 0, :, :, :] = torch.flip(back_planes[:, 0, :, :, :], dims=[-2, -1]) + # XZ Plane + back_planes[:, 1, :, :, :] = torch.flip(back_planes[:, 1, :, :, :], dims=[-1]) + # YZ Plane + back_planes[:, 2, :, :, :] = torch.flip(back_planes[:, 2, :, :, :], dims=[-2]) + + # To fuse the front planes and the back planes + bs, num_planes, channels, height, width = front_planes.shape + if self.conv_fuse: + planes = torch.cat((front_planes, back_planes), dim=2) + planes = planes.reshape(-1, channels*2, height, width) + # Apply multiple convolutional layers + for layer in self.front_back_conv: + planes = layer(planes) + + planes = planes.view(bs, num_planes, -1, height, width) + # planes = self.front_back_conv(planes).view(bs, num_planes, -1, height, width) # only one layer. + elif self.swin_ca_fuse: + front_planes = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() # [8, 3, 32, 64, 64] -> [24, 32, 4096] -> [24, 4096, 32] + back_planes = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() + planes = self.swin_cross_attention(front_planes, back_planes, height, width)[0].permute(0, 2, 1).reshape(bs, num_planes, channels, height, width) + else: + planes = self.forward_planes(image, source_camera) + + # render target views + render_results = self.synthesizer(planes, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size) + assert render_results['images_rgb'].shape[0] == N, "Batch size mismatch for render_results" + assert render_results['images_rgb'].shape[1] == M, "Number of rendered views should be consistent with render_cameras" + + return { + 'planes': planes, + **render_results, + } diff --git a/openlrm/models/modulate.py b/openlrm/models/modulate.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2a0f0240cc1d596a9a544d56eac5ee7e03cc7d --- /dev/null +++ b/openlrm/models/modulate.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class ModLN(nn.Module): + """ + Modulation with adaLN. 
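+
+    In forward(), the conditioning vector mod [N, D] is mapped (SiLU -> Linear) to a shift and a
+    scale, which modulate the layer-normalized tokens x [N, L, D]:
+        out = LayerNorm(x) * (1 + scale) + shift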
+ + References: + DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101 + """ + def __init__(self, inner_dim: int, mod_dim: int, eps: float): + super().__init__() + self.norm = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(mod_dim, inner_dim * 2), + ) + + @staticmethod + def modulate(x, shift, scale): + # x: [N, L, D] + # shift, scale: [N, D] + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D] + return self.modulate(self.norm(x), shift, scale) # [N, L, D] diff --git a/openlrm/models/rendering/__init__.py b/openlrm/models/rendering/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6 --- /dev/null +++ b/openlrm/models/rendering/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/openlrm/models/rendering/synthesizer.py b/openlrm/models/rendering/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3833cfeb681e1a8f61d8244a36d03c6631d8b3de --- /dev/null +++ b/openlrm/models/rendering/synthesizer.py @@ -0,0 +1,208 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Zexin He in 2023-2024. +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import ImportanceRenderer +from .utils.ray_sampler import RaySampler + + +class ShiftedSoftplus(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return nn.functional.softplus(x - 1) + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. 
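+
+    Features sampled from the three planes are concatenated per point (rather than averaged) and
+    passed through a small MLP that predicts one sigma channel and three RGB channels; RGB is
+    squashed with the MipNeRF-style sigmoid clamping to the range (-0.001, 1.001).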
+ + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + @torch.compile + def forward(self, sampled_features, ray_directions): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': False, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + self.ray_sampler = RaySampler() + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + def forward(self, planes, cameras, anchors, resolutions, bg_colors, region_size: int): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # anchors: (N, M, 2) + # resolutions: (N, M, 1) + # bg_colors: (N, M, 1) + # region_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + assert planes.shape[0] == anchors.shape[0], "Batch size mismatch for planes and anchors" + assert cameras.shape[1] == anchors.shape[1], "Number of views mismatch for cameras and anchors" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + resolutions=resolutions.reshape(-1, 1), + anchors=anchors.reshape(-1, 2), + region_size=region_size, + ) + assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, + bg_colors=bg_colors.reshape(-1, 1), + ) + + # Reshape into 'raw' neural-rendered image + Himg = Wimg = region_size + rgb_images = 
rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + return { + 'images_rgb': rgb_images, + 'images_depth': depth_images, + 'images_weight': weight_images, + } + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/openlrm/models/rendering/utils/__init__.py b/openlrm/models/rendering/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/openlrm/models/rendering/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
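Aside: a minimal sketch of how the grid query exposed by TriplaneSynthesizer.forward_grid above could feed a mesh-extraction step. This is not part of the diff; PyMCubes/trimesh, the helper name planes_to_mesh, the 64-cube grid, and the 5.0 density threshold are illustrative assumptions.

import mcubes    # assumed dependency: PyMCubes, for marching cubes
import trimesh   # assumed dependency: trimesh, for mesh export
import torch

@torch.no_grad()
def planes_to_mesh(synthesizer, planes, grid_size=64, threshold=5.0, path="mesh.ply"):
    # Query activated (rgb, sigma) values on a regular grid inside the triplane bounding box.
    grid = synthesizer.forward_grid(planes=planes, grid_size=grid_size)
    sigma = grid['sigma'][0, ..., 0].float().cpu().numpy()  # (G, G, G) density volume

    # Extract an iso-surface; the threshold is an arbitrary illustrative choice.
    vertices, faces = mcubes.marching_cubes(sigma, threshold)

    # Map voxel indices back into the [-1, 1] box used by the renderer before exporting.
    vertices = vertices / (grid_size - 1) * 2.0 - 1.0
    trimesh.Trimesh(vertices, faces).export(path)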
diff --git a/openlrm/models/rendering/utils/math_utils.py b/openlrm/models/rendering/utils/math_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af --- /dev/null +++ b/openlrm/models/rendering/utils/math_utils.py @@ -0,0 +1,118 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. 
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/openlrm/models/rendering/utils/ray_marcher.py b/openlrm/models/rendering/utils/ray_marcher.py new file mode 100644 index 0000000000000000000000000000000000000000..8c686c196e043f44e2276f16b4a32e596c802e40 --- /dev/null +++ b/openlrm/models/rendering/utils/ray_marcher.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He in 2023-2024. +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 
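+
+The compositing in MipRayMarcher2.run_forward uses midpoint colors/densities of adjacent samples
+and the usual quadrature weights:
+    alpha_i  = 1 - exp(-sigma_i * delta_i)
+    weight_i = alpha_i * prod_{j < i} (1 - alpha_j)
+    rgb      = sum_i weight_i * color_i  (+ background color weighted by the residual 1 - sum_i weight_i)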
+""" + +import torch +import torch.nn as nn + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + + def run_forward(self, colors, densities, depths, rendering_options, bg_colors=None): + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = self.activation_factory(rendering_options)(densities_mid) + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + else: + assert bg_colors is not None, "Must provide bg_colors if white_back is False" + composite_rgb = composite_rgb + bg_colors.unsqueeze(-1) * (1 - weight_total) + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + + def forward(self, colors, densities, depths, rendering_options, bg_colors=None): + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options, bg_colors=bg_colors) + + return composite_rgb, composite_depth, weights diff --git a/openlrm/models/rendering/utils/ray_sampler.py b/openlrm/models/rendering/utils/ray_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab9594be66d02df79ec2295dbd064906f748c2c --- /dev/null +++ b/openlrm/models/rendering/utils/ray_sampler.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He in 2023-2024. +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + @torch.compile + def forward(self, cam2world_matrix, intrinsics, resolutions, anchors, region_size): + """ + Create batches of rays and return origins and directions. 
+ + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + resolutions: (N, 1) + anchors: (N, 2) + region_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + N, M = cam2world_matrix.shape[0], region_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(region_size, dtype=torch.float32, device=cam2world_matrix.device), + torch.arange(region_size, dtype=torch.float32, device=cam2world_matrix.device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + # anchors are indexed as normal (row, col) but uv is indexed as (x, y) + x_cam = (uv[:, :, 0].view(N, -1) + anchors[:, 1].unsqueeze(-1)) * (1./resolutions) + (0.5/resolutions) + y_cam = (uv[:, :, 1].view(N, -1) + anchors[:, 0].unsqueeze(-1)) * (1./resolutions) + (0.5/resolutions) + z_cam = torch.ones((N, M), device=cam2world_matrix.device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs diff --git a/openlrm/models/rendering/utils/renderer.py b/openlrm/models/rendering/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..628fc029877f8e069feb20d3b310ef4692b4a4bc --- /dev/null +++ b/openlrm/models/rendering/utils/renderer.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Zexin He in 2023-2024. +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. 
+ + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + + coordinates = (2/box_warp) * coordinates # add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', padding_mode='zeros', align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. 
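+        Samples whose coordinates fall outside [sampler_bbox_min, sampler_bbox_max] receive zero
+        color and a very large negative raw density, so they vanish after the softplus activation.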
+ """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 8 + DATA_TYPE = _out['sigma'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bg_colors=None): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + # Coarse Pass + colors_coarse, densities_coarse = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + # Aggregate + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bg_colors=bg_colors) + else: + rgb_final, depth_final, weights = 
self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, bg_colors=bg_colors) + + return rgb_final, depth_final, weights.sum(2) + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. 
+ """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom [B, H//Mh, W//Mh, Mw, Mw, C] + # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C] + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Restore each window into a feature map. 
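+ This is the inverse operation of window_partition.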
+ Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size(M) + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C] + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C] + # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C] + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop2 = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class WindowCrossAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # [Mh, Mw] + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH] + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw] + coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw] + # [2, Mh*Mw, 1] - [2, 1, Mh*Mw] + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2] + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw] + self.register_buffer("relative_position_index", relative_position_index) + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, kv, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, Mh*Mw, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # [batch_size*num_windows, Mh*Mw, total_embed_dim] + B_, N, C = x.shape + # q(): -> [batch_size*num_windows, Mh*Mw, 1*total_embed_dim] + # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head] + # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] + q = self.q(x).reshape(B_, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] + + kv = self.kv(kv).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw] + # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw] + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + # mask: [nW, Mh*Mw, Mh*Mw] + nW = mask.shape[0] # num_windows + # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw] + # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + 
attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] + # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head] + # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim] + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerCABlock(nn.Module): + r""" Swin Transformer Cross Attention Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowCrossAttention( + dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, kv, attn_mask): + H, W = self.H, self.W + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + kv = self.norm1(kv) + kv = kv.view(B, H, W, C) + + # pad feature maps to multiples of window size + # Pad the feature map to multiples of the window size. 
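+ # F.pad pads from the last dimension backwards: (0, 0) keeps C unchanged, (pad_l, pad_r) pads W and (pad_t, pad_b) pads H.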
+ pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + kv = F.pad(kv, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_kv = torch.roll(kv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + shifted_kv = kv + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C] + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C] + + kv_windows = window_partition(shifted_kv, self.window_size) # [nW*B, Mh, Mw, C] + kv_windows = kv_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C] + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, kv_windows, mask=attn_mask) # [nW*B, Mh*Mw, C] + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C] + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C] + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + # Remove the padded data from the front. + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class CrossAttentionLayer(nn.Module): + def __init__(self, dim, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm,): + super().__init__() + self.dim = dim + self.depth = depth + self.window_size = window_size + self.shift_size = window_size // 2 + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerCABlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + def create_mask(self, x, H, W): + # calculate attention mask for SW-MSA + # Ensure that Hp and Wp are multiples of window_size. + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + # Have the same channel arrangement as the feature map for ease of subsequent window_partition. 
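+ # The 3x3 grid of (h, w) slices labels each region of the padded feature map with a distinct index; token pairs from different regions later receive -100 in the mask so shifted windows do not attend across region boundaries.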
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1] + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1] + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] + # [nW, Mh*Mw, Mh*Mw] + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + + def forward(self, x, kv, H, W): + attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw] + for blk in self.blocks: + blk.H, blk.W = H, W + x = blk(x, kv, attn_mask) + return x, H, W + + + +if __name__ == '__main__': + + shape = [8, 3, 32, 64, 64] + tensor = torch.zeros(shape) + _, _, _, H, W = tensor.shape + front_plane = tensor.reshape(-1, 32, 64*64).permute(0, 2,1).contiguous() + + back_plane = torch.zeros(front_plane.shape) + + model = CrossAttentionLayer( + dim=32, + depth=2, + num_heads=8, + window_size=2, + ) + + output = model(front_plane, back_plane, H, W) + diff --git a/openlrm/models/transformer.py b/openlrm/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..09217957ca643bb07283534ac58139360bda3524 --- /dev/null +++ b/openlrm/models/transformer.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +import torch +import torch.nn as nn +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class TransformerDecoder(nn.Module): + + """ + Transformer blocks that process the input and optionally use condition and modulation. 
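+ Supported block types: 'basic', 'cond', 'mod' (not implemented) and 'cond_mod'.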
+ """ + + def __init__(self, block_type: str, + num_layers: int, num_heads: int, + inner_dim: int, cond_dim: int = None, mod_dim: int = None, + eps: float = 1e-6, + lora_rank: int = 0): + super().__init__() + self.block_type = block_type + self.layers = nn.ModuleList([ + self._block_fn(inner_dim, cond_dim, mod_dim, lora_rank=lora_rank)( + num_heads=num_heads, + eps=eps, + ) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + + @property + def block_type(self): + return self._block_type + + @block_type.setter + def block_type(self, block_type): + assert block_type in ['basic', 'cond', 'mod', 'cond_mod'], \ + f"Unsupported block type: {block_type}" + self._block_type = block_type + + def _block_fn(self, inner_dim, cond_dim, mod_dim, lora_rank=0): + assert inner_dim is not None, f"inner_dim must always be specified" + if self.block_type == 'basic': + assert cond_dim is None and mod_dim is None, \ + f"Condition and modulation are not supported for BasicBlock" + from .block import BasicBlock + logger.debug(f"Using BasicBlock") + return partial(BasicBlock, inner_dim=inner_dim) + elif self.block_type == 'cond': + assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock" + assert mod_dim is None, f"Modulation dimension is not supported for ConditionBlock" + from .block import ConditionBlock + logger.debug(f"Using ConditionBlock") + return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim) + elif self.block_type == 'mod': + logger.error(f"modulation without condition is not implemented") + raise NotImplementedError(f"modulation without condition is not implemented") + elif self.block_type == 'cond_mod': + assert cond_dim is not None and mod_dim is not None, \ + f"Condition and modulation dimensions must be specified for ConditionModulationBlock" + from .block import ConditionModulationBlock + logger.debug(f"Using ConditionModulationBlock") + return partial(ConditionModulationBlock, inner_dim=inner_dim, cond_dim=cond_dim, mod_dim=mod_dim, lora_rank=lora_rank) + else: + raise ValueError(f"Unsupported block type during runtime: {self.block_type}") + + def assert_runtime_integrity(self, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor): + assert x is not None, f"Input tensor must be specified" + if self.block_type == 'basic': + assert cond is None and mod is None, \ + f"Condition and modulation are not supported for BasicBlock" + elif self.block_type == 'cond': + assert cond is not None and mod is None, \ + f"Condition must be specified and modulation is not supported for ConditionBlock" + elif self.block_type == 'mod': + raise NotImplementedError(f"modulation without condition is not implemented") + else: + assert cond is not None and mod is not None, \ + f"Condition and modulation must be specified for ConditionModulationBlock" + + def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor, mod: torch.Tensor): + if self.block_type == 'basic': + return layer(x) + elif self.block_type == 'cond': + return layer(x, cond) + elif self.block_type == 'mod': + return layer(x, mod) + else: + return layer(x, cond, mod) + + def forward(self, x: torch.Tensor, cond: torch.Tensor = None, mod: torch.Tensor = None): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] or None + # mod: [N, D_mod] or None + self.assert_runtime_integrity(x, cond, mod) + for layer in self.layers: + x = self.forward_layer(layer, x, cond, mod) + x = self.norm(x) + return x diff --git a/openlrm/models/utils.py b/openlrm/models/utils.py new file 
mode 100644 index 0000000000000000000000000000000000000000..682f75a42e35110f9847fae849240662dc1f08e6 --- /dev/null +++ b/openlrm/models/utils.py @@ -0,0 +1,54 @@ +''' +This is to save and load the model. +''' + +def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None): + """ + Maintain all checkpoint keys. Ignore keys with specific endings if absent. + Raise exception for model keys not in checkpoint unless ignored. + ckpt: The state dictionary of the checkpoint. + model_state_dict: The state dictionary of the model. + special_endings: A list of specific endings of strings to be ignored. + """ + filtered_ckpt = {} + special_modules =[] + for key in model_state_dict.keys(): + if key in ckpt_state_dict: + filtered_ckpt[key] = ckpt_state_dict[key] + elif any(special_str in key for special_str in special_strs): + special_modules.append(key) + continue + else: + raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.") + +def remove_module_prefix(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith('module.'): + new_key = key[len('module.'):] + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + return new_state_dict + + + +# This is for reducing impact at the beginning of training. +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None): + filtered_ckpt = {} + for key in model_state_dict.keys(): + if key in ckpt_state_dict and any(need_str in key for need_str in need_strs): + filtered_ckpt[key] = ckpt_state_dict[key] + else: + continue + + return filtered_ckpt diff --git a/openlrm/runners/__init__.py b/openlrm/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b01c77e2357b2bdec2bb1bc9866d54d8d5e58ed --- /dev/null +++ b/openlrm/runners/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from openlrm.utils.registry import Registry + +REGISTRY_RUNNERS = Registry() + +from .train import * +from .infer import * diff --git a/openlrm/runners/abstract.py b/openlrm/runners/abstract.py new file mode 100644 index 0000000000000000000000000000000000000000..76916e805a5cfbf333d2d63e8607811939a5a639 --- /dev/null +++ b/openlrm/runners/abstract.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod + + +class Runner(ABC): + """Abstract runner class""" + + def __init__(self): + pass + + @abstractmethod + def run(self): + pass diff --git a/openlrm/runners/infer/__init__.py b/openlrm/runners/infer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf897a0d2a887cdd53f0c67016f7ab197a3788b --- /dev/null +++ b/openlrm/runners/infer/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .lrm import LRMInferrer diff --git a/openlrm/runners/infer/base_inferrer.py b/openlrm/runners/infer/base_inferrer.py new file mode 100644 index 0000000000000000000000000000000000000000..2725152266dbf52a13bf308a0df0e77e09aa2542 --- /dev/null +++ b/openlrm/runners/infer/base_inferrer.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from abc import abstractmethod +from accelerate import Accelerator +from accelerate.logging import get_logger + +from openlrm.runners.abstract import Runner + + +logger = get_logger(__name__) + + +class Inferrer(Runner): + + EXP_TYPE: str = None + + def __init__(self): + super().__init__() + + torch._dynamo.config.disable = True + self.accelerator = Accelerator() + + self.model : torch.nn.Module = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + @property + def device(self): + return self.accelerator.device + + @abstractmethod + def _build_model(self, cfg): + pass + + @abstractmethod + def infer_single(self, *args, **kwargs): + pass + + @abstractmethod + def infer(self): + pass + + def run(self): + self.infer() diff --git a/openlrm/runners/infer/lrm.py b/openlrm/runners/infer/lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..2f90e19bc245638ee7ef6028f1c0f5e97dca7140 --- /dev/null +++ b/openlrm/runners/infer/lrm.py @@ -0,0 +1,403 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import os +import argparse +import mcubes +import trimesh +import safetensors +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +from tqdm.auto import tqdm +from accelerate.logging import get_logger +from huggingface_hub import hf_hub_download + +from .base_inferrer import Inferrer +from openlrm.datasets.cam_utils import build_camera_principle, build_camera_standard, surrounding_views_linspace, create_intrinsics +from openlrm.utils.logging import configure_logger +from openlrm.runners import REGISTRY_RUNNERS +from openlrm.utils.video import images_to_video +from openlrm.utils.hf_hub import wrap_model_hub + + +logger = get_logger(__name__) + + +def parse_configs(): + + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str) + parser.add_argument('--infer', type=str) + args, unknown = parser.parse_known_args() + + cfg = OmegaConf.create() + cli_cfg = OmegaConf.from_cli(unknown) + + # parse from ENV + if os.environ.get('APP_INFER') is not None: + args.infer = os.environ.get('APP_INFER') + if os.environ.get('APP_MODEL_NAME') is not None: + cli_cfg.model_name = os.environ.get('APP_MODEL_NAME') + if os.environ.get('APP_PRETRAIN_MODEL_NAME') is not None: + cli_cfg.pretrain_model_hf = os.environ.get('APP_PRETRAIN_MODEL_NAME') + + if args.config is not None: + cfg_train = OmegaConf.load(args.config) + cfg.source_size = cfg_train.dataset.source_image_res + cfg.render_size = cfg_train.dataset.render_image.high + _relative_path = os.path.join(cfg_train.experiment.parent, cfg_train.experiment.child, os.path.basename(cli_cfg.model_name).split('_')[-1]) + cfg.video_dump = os.path.join("exps", 'videos', _relative_path) + cfg.mesh_dump = os.path.join("exps", 'meshes', _relative_path) + + if args.infer is not None: + cfg_infer = OmegaConf.load(args.infer) + cfg.merge_with(cfg_infer) + if hasattr(cfg, 'experiment') and hasattr(cfg.experiment, 'parent'): + cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, cfg.experiment.parent, cfg.experiment.child, 'videos')) + cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, cfg.experiment.parent, cfg.experiment.child, 'meshes')) + else: + cfg.setdefault('video_dump', os.path.join("dumps", cli_cfg.model_name, 'videos')) + cfg.setdefault('mesh_dump', os.path.join("dumps", cli_cfg.model_name, 'meshes')) + + cfg.setdefault('double_sided', False) + cfg.setdefault('pretrain_model_hf', None) + cfg.merge_with(cli_cfg) + + """ + [required] + model_name: str + image_input: str + export_video: bool + export_mesh: bool + + [special] + source_size: int + render_size: int + video_dump: str + mesh_dump: str + + [default] + render_views: int + render_fps: int + mesh_size: int + mesh_thres: float + frame_size: int + logger: str + """ + + cfg.setdefault('inferrer', {}) + cfg['inferrer'].setdefault('logger', 'INFO') + + # assert not (args.config is not None and args.infer is not None), "Only one of config and infer should be provided" + assert cfg.model_name is not None, "model_name is required" + if not os.environ.get('APP_ENABLED', None): + assert cfg.image_input is not 
None, "image_input is required" + assert cfg.export_video or cfg.export_mesh, \ + "At least one of export_video or export_mesh should be True" + cfg.app_enabled = False + else: + cfg.app_enabled = True + + return cfg + + +@REGISTRY_RUNNERS.register('infer.lrm') +class LRMInferrer(Inferrer): + + EXP_TYPE: str = 'lrm' + + def __init__(self): + super().__init__() + + self.cfg = parse_configs() + configure_logger( + stream_level=self.cfg.inferrer.logger, + log_level=self.cfg.inferrer.logger, + ) + + self.model = self._build_model(self.cfg).to(self.device) + + def _load_checkpoint(self, cfg): + ckpt_root = os.path.join( + cfg.saver.checkpoint_root, + cfg.experiment.parent, cfg.experiment.child, + ) + if not os.path.exists(ckpt_root): + raise FileNotFoundError(f"The checkpoint directory '{ckpt_root}' does not exist.") + ckpt_dirs = os.listdir(ckpt_root) + iter_number = "{:06}".format(cfg.inferrer.iteration) + if iter_number not in ckpt_dirs: + raise FileNotFoundError(f"Checkpoint for iteration '{iter_number}' not found in '{ckpt_root}'.") + inferrer_ckpt_path = os.path.join(ckpt_root, iter_number, 'model.safetensors') + logger.info(f"======== Auto-resume from {inferrer_ckpt_path} ========") + return inferrer_ckpt_path + + def _build_model(self, cfg): + from openlrm.models import model_dict + if cfg.inferrer.hugging_face is True: # for huggingface infer + hf_model_cls = wrap_model_hub(model_dict[self.EXP_TYPE]) + model = hf_model_cls.from_pretrained(cfg.model_name) + if cfg.double_sided: + pretrain_model_path = hf_hub_download(repo_id=cfg.pretrain_model_hf, filename='model.safetensors') + safetensors.torch.load_model( # load the pretrain model after load the Tailor3D finetune part. + model, + pretrain_model_path, + strict=False + ) + else: # for common infer + model = model_dict[self.EXP_TYPE](**cfg['model']) + inferrer_ckpt_path = self._load_checkpoint(cfg) + if cfg.double_sided: + pretrain_model_path = hf_hub_download(repo_id=cfg.pretrain_model_hf, filename='model.safetensors') + safetensors.torch.load_model( # load the pretrain model. + model, + pretrain_model_path, + strict=False + ) + safetensors.torch.load_model( # load the finetune model. 
+ model, + inferrer_ckpt_path, + strict=False + ) + else: + safetensors.torch.load_model( + model, + inferrer_ckpt_path, + ) + return model + + @staticmethod + def save_images(images, output_path): + os.makedirs((output_path), exist_ok=True) + for i in range(images.shape[0]): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + Image.fromarray(frame).save(os.path.join(output_path, f"{str(i)}.png")) + + def _default_source_camera(self, dist_to_center: float = 2.0, batch_size: int = 1, device: torch.device = torch.device('cpu')): + # return: (N, D_cam_raw) + canonical_camera_extrinsics = torch.tensor([[ + [1, 0, 0, 0], + [0, 0, -1, -dist_to_center], + [0, 1, 0, 0], + ]], dtype=torch.float32, device=device) + canonical_camera_intrinsics = create_intrinsics( + f=0.75, + c=0.5, + device=device, + ).unsqueeze(0) + source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics) + return source_camera.repeat(batch_size, 1) + + def _default_render_cameras(self, n_views: int, batch_size: int = 1, device: torch.device = torch.device('cpu')): + # return: (N, M, D_cam_render) + render_camera_extrinsics = surrounding_views_linspace(n_views=n_views, device=device) + render_camera_intrinsics = create_intrinsics( + f=0.75, + c=0.5, + device=device, + ).unsqueeze(0).repeat(render_camera_extrinsics.shape[0], 1, 1) + render_cameras = build_camera_standard(render_camera_extrinsics, render_camera_intrinsics) + return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1) + + def infer_planes(self, image: torch.Tensor, source_cam_dist: float, back_image=None): + N = image.shape[0] + source_camera = self._default_source_camera(dist_to_center=source_cam_dist, batch_size=N, device=self.device) + front_planes = self.model.forward_planes(image, source_camera) + if back_image is not None: + back_planes = self.model.forward_planes(back_image, source_camera) + # XY Plane + back_planes[:, 0, :, :, :] = torch.flip(back_planes[:, 0, :, :, :], dims=[-2, -1]) + # XZ Plane + back_planes[:, 1, :, :, :] = torch.flip(back_planes[:, 1, :, :, :], dims=[-1]) + # YZ Plane + back_planes[:, 2, :, :, :] = torch.flip(back_planes[:, 2, :, :, :], dims=[-2]) + + # To fuse the front planes and the back planes + bs, num_planes, channels, height, width = front_planes.shape + if 'conv_fuse' in self.cfg['model']: + planes = torch.cat((front_planes, back_planes), dim=2) + planes = planes.reshape(-1, channels*2, height, width) + # planes = self.model.front_back_conv(planes).view(bs, num_planes, -1, height, width) # only one layer. 
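+ # After channel-wise concatenation and reshape, planes has shape [bs * num_planes, 2 * channels, height, width].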
+ # Apply multiple convolutional layers + for layer in self.model.front_back_conv: + planes = layer(planes) + + planes = planes.view(bs, num_planes, -1, height, width) + elif 'swin_ca_fuse' in self.cfg['model']: + front_planes = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() # [8, 3, 32, 64, 64] -> [24, 32, 4096] -> [24, 4096, 32] + back_planes = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous() + planes = self.model.swin_cross_attention(front_planes, back_planes, height, width)[0].permute(0, 2, 1).reshape(bs, num_planes, channels, height, width) + else: + planes = front_planes + + assert N == planes.shape[0] + return planes + + def infer_video(self, planes: torch.Tensor, frame_size: int, render_size: int, render_views: int, render_fps: int, dump_video_path: str, image_format=False): + N = planes.shape[0] + render_cameras = self._default_render_cameras(n_views=render_views, batch_size=N, device=self.device) + render_anchors = torch.zeros(N, render_cameras.shape[1], 2, device=self.device) + render_resolutions = torch.ones(N, render_cameras.shape[1], 1, device=self.device) * render_size + render_bg_colors = torch.ones(N, render_cameras.shape[1], 1, device=self.device, dtype=torch.float32) * 1. + + frames = [] + for i in range(0, render_cameras.shape[1], frame_size): + frames.append( + self.model.synthesizer( + planes=planes, + cameras=render_cameras[:, i:i+frame_size], + anchors=render_anchors[:, i:i+frame_size], + resolutions=render_resolutions[:, i:i+frame_size], + bg_colors=render_bg_colors[:, i:i+frame_size], + region_size=render_size, + ) + ) + # merge frames + frames = { + k: torch.cat([r[k] for r in frames], dim=1) + for k in frames[0].keys() + } + # dump + os.makedirs(os.path.dirname(dump_video_path), exist_ok=True) + for k, v in frames.items(): + if k == 'images_rgb': + if image_format: + self.save_images( # save the rendering images directly. 
+ v[0], + os.path.join(dump_video_path.replace('.mov', ''), 'nvs'), + ) + else: + images_to_video( + images=v[0], + output_path=dump_video_path, + fps=render_fps, + gradio_codec=self.cfg.app_enabled, + ) + + def infer_mesh(self, planes: torch.Tensor, mesh_size: int, mesh_thres: float, dump_mesh_path: str): + grid_out = self.model.synthesizer.forward_grid( + planes=planes, + grid_size=mesh_size, + ) + + vtx, faces = mcubes.marching_cubes(grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), mesh_thres) + vtx = vtx / (mesh_size - 1) * 2 - 1 + + vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0) + vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1) + vtx_colors = (vtx_colors * 255).astype(np.uint8) + + mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors) + + # dump + os.makedirs(os.path.dirname(dump_mesh_path), exist_ok=True) + mesh.export(dump_mesh_path) + + def infer_single(self, image_path: str, source_cam_dist: float, export_video: bool, export_mesh: bool, dump_video_path: str, dump_mesh_path: str, image_path_back=None): + source_size = self.cfg.inferrer.source_size + render_size = self.cfg.inferrer.render_size + render_views = self.cfg.inferrer.render_views + render_fps = self.cfg.inferrer.render_fps + mesh_size = self.cfg.inferrer.mesh_size + mesh_thres = self.cfg.inferrer.mesh_thres + frame_size = self.cfg.inferrer.frame_size + source_cam_dist = self.cfg.inferrer.source_cam_dist if source_cam_dist is None else source_cam_dist + + image_format = self.cfg.inferrer.image_format + + image = self.open_image(image_path, source_size) + if image_path_back is None: + back_image = self.open_image(image_path.replace('front', 'back'), source_size) if self.cfg.double_sided else None + else: + back_image = self.open_image(image_path_back, source_size) if self.cfg.double_sided else None + + with torch.no_grad(): + planes = self.infer_planes(image, source_cam_dist=source_cam_dist, back_image=back_image) + + results = {} + if export_video: + frames = self.infer_video(planes, frame_size=frame_size, render_size=render_size, render_views=render_views, render_fps=render_fps, dump_video_path=dump_video_path, + image_format=image_format) + results.update({ + 'frames': frames, + }) + if export_mesh: + mesh = self.infer_mesh(planes, mesh_size=mesh_size, mesh_thres=mesh_thres, dump_mesh_path=dump_mesh_path) + results.update({ + 'mesh': mesh, + }) + + def data_init(self): + image_paths = [] + if os.path.isfile(self.cfg.image_input): + omit_prefix = os.path.dirname(self.cfg.image_input) + image_paths.append(self.cfg.image_input) + else: + omit_prefix = self.cfg.image_input + if self.cfg.double_sided: # double sided + walk_path = os.path.join(self.cfg.image_input, 'front') + else: + walk_path = self.cfg.image_input + for root, dirs, files in os.walk(walk_path): + for file in files: + if file.endswith('.png'): + image_paths.append(os.path.join(root, file)) + image_paths.sort() + # alloc to each DDP worker + image_paths = image_paths[self.accelerator.process_index::self.accelerator.num_processes] + + return image_paths, omit_prefix + + def open_image(self, image_path, source_size): + # prepare image: [1, C_img, H_img, W_img], 0-1 scale + image = torch.from_numpy(np.array(Image.open(image_path))).to(self.device) + image = image.permute(2, 0, 1).unsqueeze(0) / 255.0 + if image.shape[1] == 4: # RGBA + image = image[:, :3, ...] * image[:, 3:, ...] 
+ (1 - image[:, 3:, ...]) + image = torch.nn.functional.interpolate(image, size=(source_size, source_size), mode='bicubic', align_corners=True) + image = torch.clamp(image, 0, 1) + + return image + + def infer(self): + image_paths, omit_prefix = self.data_init() + for image_path in tqdm(image_paths, disable=not self.accelerator.is_local_main_process): + + # prepare dump paths + image_name = os.path.basename(image_path) + uid = image_name.split('.')[0] + subdir_path = os.path.dirname(image_path).replace(omit_prefix, '') + subdir_path = subdir_path[1:] if subdir_path.startswith('/') else subdir_path + dump_video_path = os.path.join( + self.cfg.video_dump, + subdir_path, + f'{uid}.mov', + ) + dump_mesh_path = os.path.join( + self.cfg.mesh_dump, + subdir_path, + f'{uid}.ply', + ) + + self.infer_single( + image_path, + source_cam_dist=None, + export_video=self.cfg.export_video, + export_mesh=self.cfg.export_mesh, + dump_video_path=dump_video_path, + dump_mesh_path=dump_mesh_path, + ) diff --git a/openlrm/runners/train/__init__.py b/openlrm/runners/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e837a0ba6fd8894e3fec916b72e9543ebc2b3db2 --- /dev/null +++ b/openlrm/runners/train/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .lrm import LRMTrainer diff --git a/openlrm/runners/train/base_trainer.py b/openlrm/runners/train/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..420455e308e94a6768c8fef93b74c5c5a38b4305 --- /dev/null +++ b/openlrm/runners/train/base_trainer.py @@ -0,0 +1,385 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
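+# Generic trainer built on HuggingFace Accelerate: handles model/optimizer preparation, checkpoint saving, auto-resume and logging.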
+ + +import os +import time +import math +import argparse +import shutil +import torch +import safetensors +from omegaconf import OmegaConf +from abc import abstractmethod +from contextlib import contextmanager +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed + +from openlrm.utils.logging import configure_logger +from openlrm.utils.compile import configure_dynamo +from openlrm.runners.abstract import Runner + +from collections import OrderedDict +from huggingface_hub import hf_hub_download + +# def my_save_pre_hook(models, weights, output_dir): +# keep = ["_lora", "synthesizer", "front_back_conv"] +# for weight_dict in weights: +# keys_to_keep = [key for key in weight_dict if any(keep_str in key for keep_str in keep)] +# new_weight_dict = OrderedDict((key, weight_dict[key]) for key in keys_to_keep) +# weight_dict.clear() +# weight_dict.update(new_weight_dict) + +from collections import OrderedDict + +def my_save_pre_hook(models, weights, output_dir): + assert len(models) == len(weights), "Models and weights must correspond one-to-one" + + filtered_weights_list = [] + for model, model_weights in zip(models, weights): + filtered_weights = OrderedDict() + for name, param in model.named_parameters(): + if param.requires_grad: + if name in model_weights: + filtered_weights[name] = model_weights[name] + + filtered_weights_list.append(filtered_weights) + + weights.clear() + weights.extend(filtered_weights_list) + + +logger = get_logger(__name__) + + +def parse_configs(): + # Define argparse arguments + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='./assets/config.yaml') + args, unknown = parser.parse_known_args() + + # Load configuration file + cfg = OmegaConf.load(args.config) + + # Override with command-line arguments + cli_cfg = OmegaConf.from_cli(unknown) + cfg = OmegaConf.merge(cfg, cli_cfg) + + return cfg + +class Trainer(Runner): + + def __init__(self): + super().__init__() + + self.cfg = parse_configs() + self.timestamp = time.strftime("%Y%m%d-%H%M%S") + + self.accelerator = Accelerator( + mixed_precision=self.cfg.train.mixed_precision, + gradient_accumulation_steps=self.cfg.train.accum_steps, + log_with=tuple(self.cfg.logger.trackers), + project_config=ProjectConfiguration( + logging_dir=self.cfg.logger.tracker_root, + ), + use_seedable_sampler=True, + kwargs_handlers=[ + DistributedDataParallelKwargs( + find_unused_parameters=self.cfg.train.find_unused_parameters, + ), + ], + ) + self.accelerator.register_save_state_pre_hook(my_save_pre_hook) # it is the save model hook. 
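+ # The pre-hook keeps only parameters with requires_grad=True in the saved state dicts, so checkpoints contain just the trainable weights.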
+ + set_seed(self.cfg.experiment.seed, device_specific=True) + with self.accelerator.main_process_first(): + configure_logger( + stream_level=self.cfg.logger.stream_level, + log_level=self.cfg.logger.log_level, + file_path=os.path.join( + self.cfg.logger.log_root, + self.cfg.experiment.parent, self.cfg.experiment.child, + f"{self.timestamp}.log", + ) if self.accelerator.is_main_process else None, + ) + logger.info(self.accelerator.state, main_process_only=False, in_order=True) + configure_dynamo(dict(self.cfg.compile)) + + # attributes with defaults + self.model : torch.nn.Module = None + self.optimizer: torch.optim.Optimizer = None + self.scheduler: torch.optim.lr_scheduler.LRScheduler = None + self.train_loader: torch.utils.data.DataLoader = None + self.val_loader: torch.utils.data.DataLoader = None + self.N_max_global_steps: int = None + self.N_global_steps_per_epoch: int = None + self.global_step: int = 0 + self.current_epoch: int = 0 + + def __enter__(self): + self.accelerator.init_trackers( + project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}", + ) + self.prepare_everything() + self.log_inital_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.accelerator.end_training() + + @staticmethod + def control(option: str = None, synchronized: bool = False): + def decorator(func): + def wrapper(self, *args, **kwargs): + if option is None or hasattr(self.accelerator, option): + accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func + result = accelerated_func(self, *args, **kwargs) + if synchronized: + self.accelerator.wait_for_everyone() + return result + else: + raise AttributeError(f"Accelerator has no attribute {option}") + return wrapper + return decorator + + @contextmanager + def exec_in_order(self): + for rank in range(self.accelerator.num_processes): + try: + if self.accelerator.process_index == rank: + yield + finally: + self.accelerator.wait_for_everyone() + + @property + def device(self): + return self.accelerator.device + + @property + def is_distributed(self) -> bool: + return self.accelerator.num_processes > 1 + + def prepare_everything(self, is_dist_validation: bool = True): + # prepare with accelerator + if is_dist_validation: + self.model, self.optimizer, self.train_loader, self.val_loader = \ + self.accelerator.prepare( + self.model, self.optimizer, self.train_loader, self.val_loader, + ) + else: + self.model, self.optimizer, self.train_loader = \ + self.accelerator.prepare( + self.model, self.optimizer, self.train_loader, + ) + self.accelerator.register_for_checkpointing(self.scheduler) + # prepare stats + N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps + self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps) + self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs + if self.cfg.train.debug_global_steps is not None: + logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}") + self.N_max_global_steps = self.cfg.train.debug_global_steps + logger.info(f"======== Statistics ========") + logger.info(f"** N_max_global_steps: {self.N_max_global_steps}") + logger.info(f"** N_total_batch_size: {N_total_batch_size}") + logger.info(f"** N_epochs: {self.cfg.train.epochs}") + logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}") + logger.debug(f"** Prepared loader length: {len(self.train_loader)}") 
+ logger.info(f"** Distributed validation: {is_dist_validation}") + logger.info(f"============================") + logger.info(f"======== Trainable parameters ========") + logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") + for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children(): + logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}") + logger.info(f"=====================================") + self.accelerator.wait_for_everyone() + # load checkpoint or model + self.load_ckpt_or_auto_resume_(self.cfg) + # register hooks + self.register_hooks() + + @abstractmethod + def register_hooks(self): + pass + + def auto_resume_(self, cfg) -> bool: + ckpt_root = os.path.join( + cfg.saver.checkpoint_root, + cfg.experiment.parent, cfg.experiment.child, + ) + if not os.path.exists(ckpt_root): + return False + ckpt_dirs = os.listdir(ckpt_root) + if len(ckpt_dirs) == 0: + return False + ckpt_dirs.sort() + latest_ckpt = ckpt_dirs[-1] + latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt) + logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========") + self.accelerator.load_state(latest_ckpt_dir, strict=cfg.saver.load_model_func_kwargs.strict) + self.global_step = int(latest_ckpt) + self.current_epoch = self.global_step // self.N_global_steps_per_epoch + return True + + def load_model_(self, cfg): + if cfg.saver.load_model.type == 'hugging_face': + repo_id, file_name = os.path.dirname(cfg.saver.load_model.url), os.path.basename(cfg.saver.load_model.url) + pretrain_model_path = hf_hub_download(repo_id=repo_id, filename=file_name) + logger.info(f"======== Loading pretrain model from hugging face {repo_id, file_name} ========") + safetensors.torch.load_model( + self.accelerator.unwrap_model(self.model), + pretrain_model_path, + **cfg.saver.load_model_func_kwargs + ) + logger.info(f"======== Pretrain Model loaded ========") + return True + else: + logger.info(f"======== Loading model from {cfg.saver.load_model} ========") + safetensors.torch.load_model( + self.accelerator.unwrap_model(self.model), + cfg.saver.load_model, + strict=True, + ) + logger.info(f"======== Model loaded ========") + return True + + @control(synchronized=True) + def load_ckpt_or_auto_resume_(self, cfg): + # auto resume has higher priority, load model from path if auto resume is not available + # cfg.saver.auto_resume and cfg.saver.load_model + if cfg.saver.auto_resume: + successful_resume = self.auto_resume_(cfg) + if successful_resume: + if cfg.saver.load_model: + successful_load = self.load_model_(cfg) + if successful_load: + return + return + if cfg.saver.load_model: + successful_load = self.load_model_(cfg) + if successful_load: + return + logger.debug(f"======== No checkpoint or model is loaded ========") + + @control('on_main_process', synchronized=True) + def save_checkpoint(self): + ckpt_dir = os.path.join( + self.cfg.saver.checkpoint_root, + self.cfg.experiment.parent, self.cfg.experiment.child, + f"{self.global_step:06d}", + ) + self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True) + logger.info(f"======== Saved checkpoint at global step {self.global_step} ========") + # manage stratified checkpoints + ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir)) + ckpt_dirs.sort() + max_ckpt = int(ckpt_dirs[-1]) + ckpt_base = int(self.cfg.saver.checkpoint_keep_level) + ckpt_period = self.cfg.saver.checkpoint_global_steps + logger.debug(f"Checkpoint base: {ckpt_base}") + 
logger.debug(f"Checkpoint period: {ckpt_period}") + cur_order = ckpt_base ** math.floor(math.log(max_ckpt // ckpt_period, ckpt_base)) + cur_idx = 0 + while cur_order > 0: + cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base + while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit: + if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0: + shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx])) + logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}") + cur_idx += 1 + cur_order //= ckpt_base + + @property + def global_step_in_epoch(self): + return self.global_step % self.N_global_steps_per_epoch + + @abstractmethod + def _build_model(self): + pass + + @abstractmethod + def _build_optimizer(self): + pass + + @abstractmethod + def _build_scheduler(self): + pass + + @abstractmethod + def _build_dataloader(self): + pass + + @abstractmethod + def _build_loss_fn(self): + pass + + @abstractmethod + def train(self): + pass + + @abstractmethod + def evaluate(self): + pass + + @staticmethod + def _get_str_progress(epoch: int = None, step: int = None): + if epoch is not None: + log_type = 'epoch' + log_progress = epoch + elif step is not None: + log_type = 'step' + log_progress = step + else: + raise ValueError('Either epoch or step must be provided') + return log_type, log_progress + + @control('on_main_process') + def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs): + log_type, log_progress = self._get_str_progress(epoch, step) + split = f'/{split}' if split else '' + for key, value in scalar_kwargs.items(): + self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress) + + @control('on_main_process') + def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}): + for tracker in self.accelerator.trackers: + if hasattr(tracker, 'log_images'): + tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {})) + + @control('on_main_process') + def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []): + log_type, log_progress = self._get_str_progress(epoch, step) + assert self.optimizer is not None, 'Optimizer is not initialized' + if not attrs: + logger.warning('No optimizer attributes are provided, nothing will be logged') + if not group_ids: + logger.warning('No optimizer group ids are provided, nothing will be logged') + for attr in attrs: + assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}' + for group_id in group_ids: + self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress) + + @control('on_main_process') + def log_inital_info(self): + assert self.model is not None, 'Model is not initialized' + assert self.optimizer is not None, 'Optimizer is not initialized' + assert self.scheduler is not None, 'Scheduler is not initialized' + self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"}) + self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"}) + self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"}) + self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"}) + + def run(self): + self.train() diff --git a/openlrm/runners/train/lrm.py b/openlrm/runners/train/lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5933d006a7a755cdc45e6ebc101127dd8b5c66 --- /dev/null +++ 
b/openlrm/runners/train/lrm.py @@ -0,0 +1,427 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import math +from tqdm.auto import tqdm +import torch +import torch.nn as nn +from torchvision.utils import make_grid +from accelerate.logging import get_logger + +from .base_trainer import Trainer +from openlrm.utils.profiler import DummyProfiler +from openlrm.runners import REGISTRY_RUNNERS + + +logger = get_logger(__name__) + + +@REGISTRY_RUNNERS.register('train.lrm') +class LRMTrainer(Trainer): + def __init__(self): + super().__init__() + + self.model = self._build_model(self.cfg) + self.optimizer = self._build_optimizer(self.model, self.cfg) + self.train_loader, self.val_loader = self._build_dataloader(self.cfg) + self.scheduler = self._build_scheduler(self.optimizer, self.cfg) + self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg) + + def _build_model(self, cfg): + assert cfg.experiment.type == 'lrm', \ + f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}" + from openlrm.models import ModelLRM + model = ModelLRM(**cfg.model) + return model + + def _build_optimizer(self, model: nn.Module, cfg): + decay_params, no_decay_params = [], [] + + # add all bias and LayerNorm params to no_decay_params + for name, module in model.named_modules(): + if isinstance(module, nn.LayerNorm): + no_decay_params.extend([p for p in module.parameters()]) + elif hasattr(module, 'bias') and module.bias is not None: + no_decay_params.append(module.bias) + + # add remaining parameters to decay_params + _no_decay_ids = set(map(id, no_decay_params)) + decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids] + + # filter out parameters with no grad + decay_params = list(filter(lambda p: p.requires_grad, decay_params)) + no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params)) + + # monitor this to make sure we don't miss any parameters + logger.info("======== Weight Decay Parameters ========") + logger.info(f"Total: {len(decay_params)}") + logger.info("======== No Weight Decay Parameters ========") + logger.info(f"Total: {len(no_decay_params)}") + + # Optimizer + opt_groups = [ + {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay}, + {'params': no_decay_params, 'weight_decay': 0.0}, + ] + optimizer = torch.optim.AdamW( + opt_groups, + lr=cfg.train.optim.lr, + betas=(cfg.train.optim.beta1, cfg.train.optim.beta2), + ) + + return optimizer + + def _build_scheduler(self, optimizer, cfg): + local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes) + total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps) + effective_warmup_iters = cfg.train.scheduler.warmup_real_iters + logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========") + logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========") + if 
cfg.train.scheduler.type == 'cosine': + from openlrm.utils.scheduler import CosineWarmupScheduler + scheduler = CosineWarmupScheduler( + optimizer=optimizer, + warmup_iters=effective_warmup_iters, + max_iters=total_global_batches, + ) + else: + raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented") + return scheduler + + def _build_dataloader(self, cfg): + # dataset class + from openlrm.datasets import MixerDataset + + # build dataset + train_dataset = MixerDataset( + split="train", + subsets=cfg.dataset.subsets, + sample_side_views=cfg.dataset.sample_side_views, + render_image_res_low=cfg.dataset.render_image.low, + render_image_res_high=cfg.dataset.render_image.high, + render_region_size=cfg.dataset.render_image.region, + source_image_res=cfg.dataset.source_image_res, + normalize_camera=cfg.dataset.normalize_camera, + normed_dist_to_center=cfg.dataset.normed_dist_to_center, + ) + val_dataset = MixerDataset( + split="val", + subsets=cfg.dataset.subsets, + sample_side_views=cfg.dataset.sample_side_views, + render_image_res_low=cfg.dataset.render_image.low, + render_image_res_high=cfg.dataset.render_image.high, + render_region_size=cfg.dataset.render_image.region, + source_image_res=cfg.dataset.source_image_res, + normalize_camera=cfg.dataset.normalize_camera, + normed_dist_to_center=cfg.dataset.normed_dist_to_center, + ) + + # build data loader + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=cfg.train.batch_size, + shuffle=True, + drop_last=True, + num_workers=cfg.dataset.num_train_workers, + pin_memory=cfg.dataset.pin_mem, + persistent_workers=True, + ) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=cfg.val.batch_size, + shuffle=False, + drop_last=False, + num_workers=cfg.dataset.num_val_workers, + pin_memory=cfg.dataset.pin_mem, + persistent_workers=False, + ) + + return train_loader, val_loader + + def _build_loss_fn(self, cfg): + from openlrm.losses import PixelLoss, LPIPSLoss, TVLoss + pixel_loss_fn = PixelLoss() + with self.accelerator.main_process_first(): + perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True) + tv_loss_fn = TVLoss() + return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn + + def register_hooks(self): + pass + + def forward_loss_local_step(self, data): + + source_camera = data['source_camera'] + render_camera = data['render_camera'] + source_image = data['source_image'] + render_image = data['render_image'] + if 'source_image_back' in data: + source_image_back = data['source_image_back'] #!!! + else: + source_image_back = None + render_anchors = data['render_anchors'] + render_full_resolutions = data['render_full_resolutions'] + render_bg_colors = data['render_bg_colors'] + + N, M, C, H, W = render_image.shape + + # forward + outputs = self.model( + image=source_image, + source_camera=source_camera, + render_cameras=render_camera, + render_anchors=render_anchors, + render_resolutions=render_full_resolutions, + render_bg_colors=render_bg_colors, + render_region_size=self.cfg.dataset.render_image.region, + image_back=source_image_back, #!!! + ) + + # loss calculation + loss = 0. 
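+ # The total loss is a weighted sum of the enabled terms below: a pixel reconstruction
+ # loss, an LPIPS perceptual loss, and a TV regularizer on the predicted planes; any
+ # term whose weight in cfg.train.loss is zero is skipped and left as None.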
+ loss_pixel = None + loss_perceptual = None + loss_tv = None + + if self.cfg.train.loss.pixel_weight > 0.: + loss_pixel = self.pixel_loss_fn(outputs['images_rgb'], render_image) + loss += loss_pixel * self.cfg.train.loss.pixel_weight + if self.cfg.train.loss.perceptual_weight > 0.: + loss_perceptual = self.perceptual_loss_fn(outputs['images_rgb'], render_image) + loss += loss_perceptual * self.cfg.train.loss.perceptual_weight + if self.cfg.train.loss.tv_weight > 0.: + loss_tv = self.tv_loss_fn(outputs['planes']) + loss += loss_tv * self.cfg.train.loss.tv_weight + + return outputs, loss, loss_pixel, loss_perceptual, loss_tv + + def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile): + self.model.train() + + local_step_losses = [] + global_step_losses = [] + + logger.debug(f"======== Starting epoch {self.current_epoch} ========") + for data in loader: + + logger.debug(f"======== Starting global step {self.global_step} ========") + with self.accelerator.accumulate(self.model): + + # forward to loss + outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data) + + # backward + self.accelerator.backward(loss) + if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.: + self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm) + self.optimizer.step() + self.optimizer.zero_grad() + + # track local losses + local_step_losses.append(torch.stack([ + _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device) + for _loss in [loss, loss_pixel, loss_perceptual, loss_tv] + ])) + + # track global step + if self.accelerator.sync_gradients: + profiler.step() + self.scheduler.step() + logger.debug(f"======== Scheduler step ========") + self.global_step += 1 + global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu() + loss, loss_pixel, loss_perceptual, loss_tv = global_step_loss.unbind() + loss_kwargs = { + 'loss': loss.item(), + 'loss_pixel': loss_pixel.item(), + 'loss_perceptual': loss_perceptual.item(), + 'loss_tv': loss_tv.item(), + } + self.log_scalar_kwargs( + step=self.global_step, split='train', + **loss_kwargs + ) + self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1]) + local_step_losses = [] + global_step_losses.append(global_step_loss) + + # manage display + pbar.update(1) + description = { + **loss_kwargs, + 'lr': self.optimizer.param_groups[0]['lr'], + } + description = '[TRAIN STEP]' + \ + ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v)) + pbar.set_description(description) + + # periodic actions + if self.global_step % self.cfg.saver.checkpoint_global_steps == 0: + self.save_checkpoint() + if self.global_step % self.cfg.val.global_step_period == 0: + self.evaluate() + self.model.train() + if self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0: + self.log_image_monitor( + step=self.global_step, split='train', + renders=outs['images_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(), + gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), + ) + + # progress control + if self.global_step >= self.N_max_global_steps: + self.accelerator.set_trigger() + break + + # track epoch + self.current_epoch += 1 + epoch_losses = torch.stack(global_step_losses).mean(dim=0) + epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv = epoch_losses.unbind() + epoch_loss_dict = { + 
'loss': epoch_loss.item(), + 'loss_pixel': epoch_loss_pixel.item(), + 'loss_perceptual': epoch_loss_perceptual.item(), + 'loss_tv': epoch_loss_tv.item(), + } + self.log_scalar_kwargs( + epoch=self.current_epoch, split='train', + **epoch_loss_dict, + ) + logger.info( + f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \ + ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v)) + ) + + def train(self): + + starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps + skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch) + logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========") + + with tqdm( + range(0, self.N_max_global_steps), + initial=self.global_step, + disable=(not self.accelerator.is_main_process), + ) as pbar: + + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule( + wait=10, warmup=10, active=100, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join( + self.cfg.logger.tracker_root, + self.cfg.experiment.parent, self.cfg.experiment.child, + )), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) if self.cfg.logger.enable_profiler else DummyProfiler() + + with profiler: + + self.optimizer.zero_grad() + for _ in range(self.current_epoch, self.cfg.train.epochs): + + loader = skipped_loader or self.train_loader + skipped_loader = None + self.train_epoch(pbar=pbar, loader=loader, profiler=profiler) + if self.accelerator.check_trigger(): + break + + logger.info(f"======== Training finished at global step {self.global_step} ========") + + # final checkpoint and evaluation + self.save_checkpoint() + self.evaluate() + + @torch.no_grad() + @torch.compiler.disable + def evaluate(self, epoch: int = None): + self.model.eval() + + max_val_batches = self.cfg.val.debug_batches or len(self.val_loader) + running_losses = [] + sample_data, sample_outs = None, None + + for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches): + + if len(running_losses) >= max_val_batches: + logger.info(f"======== Early stop validation at {len(running_losses)} batches ========") + break + + outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data) + sample_data, sample_outs = data, outs + + running_losses.append(torch.stack([ + _loss if _loss is not None else torch.tensor(float('nan'), device=self.device) + for _loss in [loss, loss_pixel, loss_perceptual, loss_tv] + ])) + + total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu() + total_loss, total_loss_pixel, total_loss_perceptual, total_loss_tv = total_losses.unbind() + total_loss_dict = { + 'loss': total_loss.item(), + 'loss_pixel': total_loss_pixel.item(), + 'loss_perceptual': total_loss_perceptual.item(), + 'loss_tv': total_loss_tv.item(), + } + + if epoch is not None: + self.log_scalar_kwargs( + epoch=epoch, split='val', + **total_loss_dict, + ) + logger.info( + f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \ + ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v)) + ) + self.log_image_monitor( + epoch=epoch, split='val', + renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), + gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), + ) + else: + 
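+ # No epoch index was given, so this is a mid-training validation pass: log the
+ # aggregated metrics and sample renders against the current global step instead.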
self.log_scalar_kwargs( + step=self.global_step, split='val', + **total_loss_dict, + ) + logger.info( + f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \ + ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v)) + ) + self.log_image_monitor( + step=self.global_step, split='val', + renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), + gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(), + ) + + @Trainer.control('on_main_process') + def log_image_monitor( + self, epoch: int = None, step: int = None, split: str = None, + renders: torch.Tensor = None, gts: torch.Tensor = None, + ): + M = renders.shape[1] + merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:]) + renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:]) + renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M) + log_type, log_progress = self._get_str_progress(epoch, step) + split = f'/{split}' if split else '' + self.log_images({ + f'Images_split{split}/rendered': renders.unsqueeze(0), + f'Images_split{split}/gt': gts.unsqueeze(0), + f'Images_merged{split}': merged.unsqueeze(0), + }, log_progress) diff --git a/openlrm/utils/__init__.py b/openlrm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1e39e624fbf5d970acc4b05714f8b9f70830c6 --- /dev/null +++ b/openlrm/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Empty diff --git a/openlrm/utils/compile.py b/openlrm/utils/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..08972a23daf1c046c327ce93fc667b706a3ec65b --- /dev/null +++ b/openlrm/utils/compile.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
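+# configure_dynamo() below copies user-supplied settings into torch._dynamo.config,
+# skipping None values and degrading gracefully when torch._dynamo is not available.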
+ + +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +def configure_dynamo(config: dict): + try: + import torch._dynamo + logger.debug(f'Configuring torch._dynamo.config with {config}') + for k, v in config.items(): + if v is None: + logger.debug(f'Skipping torch._dynamo.config.{k} with None') + continue + if hasattr(torch._dynamo.config, k): + logger.warning(f'Overriding torch._dynamo.config.{k} from {getattr(torch._dynamo.config, k)} to {v}') + setattr(torch._dynamo.config, k, v) + except ImportError: + logger.debug('torch._dynamo not found, skipping') + pass diff --git a/openlrm/utils/hf_hub.py b/openlrm/utils/hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ba0df56983a407d20c2c656a82c1ad15487ca5 --- /dev/null +++ b/openlrm/utils/hf_hub.py @@ -0,0 +1,25 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin + + +def wrap_model_hub(model_cls: nn.Module): + class HfModel(model_cls, PyTorchModelHubMixin): + def __init__(self, config: dict): + super().__init__(**config) + self.config = config + return HfModel diff --git a/openlrm/utils/logging.py b/openlrm/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6e2ecd77ff0d1dc9b7fa5cb4efc6edcda8a18d0d --- /dev/null +++ b/openlrm/utils/logging.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
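+# Logging helpers: TqdmStreamHandler routes log records through tqdm.write so console
+# output does not break active progress bars, and configure_logger() attaches that
+# stream handler (plus an optional file handler) to the project-level logger.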
+ + +import os +import logging +from tqdm.auto import tqdm + + +class TqdmStreamHandler(logging.StreamHandler): + def emit(self, record): + tqdm.write(self.format(record)) + + +def configure_logger(stream_level, log_level, file_path = None): + _stream_level = stream_level.upper() + _log_level = log_level.upper() + _project_level = _log_level + + _formatter = logging.Formatter("[%(asctime)s] %(name)s: [%(levelname)s] %(message)s") + + _stream_handler = TqdmStreamHandler() + _stream_handler.setLevel(_stream_level) + _stream_handler.setFormatter(_formatter) + + if file_path is not None: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + _file_handler = logging.FileHandler(file_path) + _file_handler.setLevel(_log_level) + _file_handler.setFormatter(_formatter) + + _project_logger = logging.getLogger(__name__.split('.')[0]) + _project_logger.setLevel(_project_level) + _project_logger.addHandler(_stream_handler) + if file_path is not None: + _project_logger.addHandler(_file_handler) diff --git a/openlrm/utils/preprocess.py b/openlrm/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..81ea7663c6040f27ff61917f0c85fe72c488f898 --- /dev/null +++ b/openlrm/utils/preprocess.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import rembg +import cv2 +import os + +def save_image_with_directory_check(save_path, image): + directory = os.path.dirname(save_path) + + if not os.path.exists(directory): + os.makedirs(directory) + + return cv2.imwrite(save_path, image) + +class Preprocessor: + + """ + Preprocessing under cv2 conventions. 
+ """ + + def __init__(self): + self.rembg_session = rembg.new_session( + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + + def preprocess(self, image_path: str, save_path: str, rmbg: bool = True, recenter: bool = True, size: int = 512, border_ratio: float = 0.2): + image = self.step_load_to_size(image_path=image_path, size=size*2) + if rmbg: + image = self.step_rembg(image_in=image) + else: + image = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA) + if recenter: + image = self.step_recenter(image_in=image, border_ratio=border_ratio, square_size=size) + else: + image = cv2.resize( + src=image, + dsize=(size, size), + interpolation=cv2.INTER_AREA, + ) + return save_image_with_directory_check(save_path, image) + + def step_rembg(self, image_in: np.ndarray) -> np.ndarray: + image_out = rembg.remove( + data=image_in, + session=self.rembg_session, + ) + return image_out + + def step_recenter(self, image_in: np.ndarray, border_ratio: float, square_size: int) -> np.ndarray: + assert image_in.shape[-1] == 4, "Image to recenter must be RGBA" + mask = image_in[..., -1] > 0 + ijs = np.nonzero(mask) + # find bbox + i_min, i_max = ijs[0].min(), ijs[0].max() + j_min, j_max = ijs[1].min(), ijs[1].max() + bbox_height, bbox_width = i_max - i_min, j_max - j_min + # recenter and resize + desired_size = int(square_size * (1 - border_ratio)) + scale = desired_size / max(bbox_height, bbox_width) + desired_height, desired_width = int(bbox_height * scale), int(bbox_width * scale) + desired_i_min, desired_j_min = (square_size - desired_height) // 2, (square_size - desired_width) // 2 + desired_i_max, desired_j_max = desired_i_min + desired_height, desired_j_min + desired_width + # create new image + image_out = np.zeros((square_size, square_size, 4), dtype=np.uint8) + image_out[desired_i_min:desired_i_max, desired_j_min:desired_j_max] = cv2.resize( + src=image_in[i_min:i_max, j_min:j_max], + dsize=(desired_width, desired_height), + interpolation=cv2.INTER_AREA, + ) + return image_out + + def step_load_to_size(self, image_path: str, size: int) -> np.ndarray: + image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) + height, width = image.shape[:2] + scale = size / max(height, width) + height, width = int(height * scale), int(width * scale) + image_out = cv2.resize( + src=image, + dsize=(width, height), + interpolation=cv2.INTER_AREA, + ) + return image_out diff --git a/openlrm/utils/profiler.py b/openlrm/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..92ba79973308b627d5b20bdd7bb09eac138c93ad --- /dev/null +++ b/openlrm/utils/profiler.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from torch.profiler import profile + + +class DummyProfiler(profile): + def __init__(self): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def step(self): + pass diff --git a/openlrm/utils/proxy.py b/openlrm/utils/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0c6e53f0a0c412debc866d172926a4c0401bba --- /dev/null +++ b/openlrm/utils/proxy.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +NO_PROXY = "OPENLRM_NO_DATA_PROXY" in os.environ + +def no_proxy(func): + """Decorator to disable proxy but then restore after the function call.""" + def wrapper(*args, **kwargs): + # http_proxy, https_proxy, HTTP_PROXY, HTTPS_PROXY, all_proxy + http_proxy = os.environ.get('http_proxy') + https_proxy = os.environ.get('https_proxy') + HTTP_PROXY = os.environ.get('HTTP_PROXY') + HTTPS_PROXY = os.environ.get('HTTPS_PROXY') + all_proxy = os.environ.get('all_proxy') + os.environ['http_proxy'] = '' + os.environ['https_proxy'] = '' + os.environ['HTTP_PROXY'] = '' + os.environ['HTTPS_PROXY'] = '' + os.environ['all_proxy'] = '' + try: + return func(*args, **kwargs) + finally: + os.environ['http_proxy'] = http_proxy + os.environ['https_proxy'] = https_proxy + os.environ['HTTP_PROXY'] = HTTP_PROXY + os.environ['HTTPS_PROXY'] = HTTPS_PROXY + os.environ['all_proxy'] = all_proxy + if NO_PROXY: + return wrapper + else: + return func diff --git a/openlrm/utils/registry.py b/openlrm/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..421a735f82899c50884cd5b5a27e71757b2eb813 --- /dev/null +++ b/openlrm/utils/registry.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
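+# Minimal name-to-class registry. Illustrative usage, mirroring how runners register
+# themselves elsewhere in this repo:
+#
+#     REGISTRY_RUNNERS = Registry()
+#
+#     @REGISTRY_RUNNERS.register('train.lrm')
+#     class LRMTrainer(Trainer):
+#         ...
+#
+#     runner_cls = REGISTRY_RUNNERS['train.lrm']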
+ + +class Registry: + """Registry class""" + + def __init__(self): + self._registry = {} + + def register(self, name): + """Register a module""" + def decorator(cls): + assert name not in self._registry, 'Module {} already registered'.format(name) + self._registry[name] = cls + return cls + return decorator + + def __getitem__(self, name): + """Get a module""" + return self._registry[name] + + def __contains__(self, name): + return name in self._registry diff --git a/openlrm/utils/scheduler.py b/openlrm/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc151d816e2787f37f9bea02b0945e06a933c01 --- /dev/null +++ b/openlrm/utils/scheduler.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from torch.optim.lr_scheduler import LRScheduler +from accelerate.logging import get_logger + + +logger = get_logger(__name__) + + +class CosineWarmupScheduler(LRScheduler): + def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): + self.warmup_iters = warmup_iters + self.max_iters = max_iters + self.initial_lr = initial_lr + super().__init__(optimizer, last_iter) + + def get_lr(self): + logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}") + if self._step_count <= self.warmup_iters: + return [ + self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters + for base_lr in self.base_lrs] + else: + cos_iter = self._step_count - self.warmup_iters + cos_max_iter = self.max_iters - self.warmup_iters + cos_theta = cos_iter / cos_max_iter * math.pi + cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] + return cos_lr diff --git a/openlrm/utils/video.py b/openlrm/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb79d5760773be29f55bb313f378bfe13ecb320 --- /dev/null +++ b/openlrm/utils/video.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
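+# images_to_video() expects a (T, C, H, W) tensor of frames with values in [0, 1],
+# scales them to uint8, and writes the result with imageio; gradio_codec=True uses the
+# writer's default codec, otherwise mpeg4 is used.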
+ + +import os +import numpy as np +import imageio + + +def images_to_video(images, output_path, fps, gradio_codec: bool, verbose=False): + # images: (T, C, H, W) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + frames = [] + for i in range(images.shape[0]): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + frames = np.stack(frames) + if gradio_codec: + imageio.mimwrite(output_path, frames, fps=fps, quality=10) + else: + imageio.mimwrite(output_path, frames, fps=fps, codec='mpeg4', quality=10) + if verbose: + print(f"Using gradio codec option {gradio_codec}") + print(f"Saved video to {output_path}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..427ae11e2efff7a65c64ff648771b1121018699d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +lpips +omegaconf +transformers +safetensors +accelerate +imageio[ffmpeg] +PyMCubes +trimesh +megfile +opencv-python +optimum[onnxruntime-gpu] +rembg[gpu,cli] +httpx[socks] +ninja +xformers +git+https://github.com/Baijiong-Lin/LoRA-Torch \ No newline at end of file