diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa68b20cfdcf3a43c73c22f1eb9b36f61166936f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,284 @@
+import os
+import torch
+import sys
+from demo import TrajCrafter
+import random
+import gradio as gr
+import random
+from inference import get_parser
+from datetime import datetime
+import argparse
+
+# 解析命令行参数
+
+traj_examples = [
+ ['20; -30; 0.3; 0; 0'],
+ ['0; 0; -0.3; -2; 2'],
+]
+
+# inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
+
+img_examples = [
+ ['test/videos/0-NNvgaTcVzAG0-r.mp4',2,1,'0; -30; 0.5; -2; 0',50,43],
+ ['test/videos/tUfDESZsQFhdDW9S.mp4',2,1,'0; 30; -0.4; 2; 0',50,43],
+ ['test/videos/part-2-3.mp4',2,1,'20; 40; 0.5; 2; 0',50,43],
+ ['test/videos/p7.mp4',2,1,'0; -50; 0.3; 0; 0',50,43],
+ ['test/videos/UST-fn-RvhJwMR5S.mp4',2,1,'0; -35; 0.4; 0; 0',50,43],
+]
+
+max_seed = 2 ** 31
+
+parser = get_parser() # infer_config.py
+opts = parser.parse_args() # default device: 'cuda:0'
+opts.weight_dtype = torch.bfloat16
+tmp = datetime.now().strftime("%Y%m%d_%H%M")
+opts.save_dir = f'./experiments/gradio_{tmp}'
+os.makedirs(opts.save_dir,exist_ok=True)
+test_tensor = torch.Tensor([0]).cuda()
+opts.device = str(test_tensor.device)
+
+CAMERA_MOTION_MODE = ["Basic Camera Trajectory", "Custom Camera Trajectory"]
+
+def show_traj(mode):
+ if mode == 'Orbit Left':
+ return gr.update(value='0; -30; 0; 0; 0',visible=True),gr.update(visible=False)
+ elif mode == 'Orbit Right':
+ return gr.update(value='0; 30; 0; 0; 0',visible=True),gr.update(visible=False)
+ elif mode == 'Orbit Up':
+ return gr.update(value='30; 0; 0; 0; 0',visible=True),gr.update(visible=False)
+ elif mode == 'Orbit Down':
+ return gr.update(value='-20; 0; 0; 0; 0',visible=True), gr.update(visible=False)
+ if mode == 'Pan Left':
+ return gr.update(value='0; 0; 0; -2; 0',visible=True),gr.update(visible=False)
+ elif mode == 'Pan Right':
+ return gr.update(value='0; 0; 0; 2; 0',visible=True),gr.update(visible=False)
+ elif mode == 'Pan Up':
+ return gr.update(value='0; 0; 0; 0; 2',visible=True),gr.update(visible=False)
+ elif mode == 'Pan Down':
+ return gr.update(value='0; 0; 0; 0; -2',visible=True), gr.update(visible=False)
+ elif mode == 'Zoom in':
+ return gr.update(value='0; 0; 0.5; 0; 0',visible=True), gr.update(visible=False)
+ elif mode == 'Zoom out':
+ return gr.update(value='0; 0; -0.5; 0; 0',visible=True), gr.update(visible=False)
+ elif mode == 'Customize':
+ return gr.update(value='0; 0; 0; 0; 0',visible=True), gr.update(visible=True)
+ elif mode == 'Reset':
+ return gr.update(value='0; 0; 0; 0; 0',visible=False), gr.update(visible=False)
+
+def trajcrafter_demo(opts):
+ # css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
+ css = """
+ #input_img {max-width: 1024px !important}
+ #output_vid {max-width: 1024px; max-height:576px}
+ #random_button {max-width: 100px !important}
+ .generate-btn {
+ background: linear-gradient(45deg, #2196F3, #1976D2) !important;
+ border: none !important;
+ color: white !important;
+ font-weight: bold !important;
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
+ }
+ .generate-btn:hover {
+ background: linear-gradient(45deg, #1976D2, #1565C0) !important;
+ box-shadow: 0 4px 8px rgba(0,0,0,0.3) !important;
+ }
+ """
+ image2video = TrajCrafter(opts,gradio=True)
+ # image2video.run_both = spaces.GPU(image2video.run_both, duration=290) # fixme
+ with gr.Blocks(analytics_enabled=False, css=css) as trajcrafter_iface:
+ gr.Markdown("
TrajectoryCrafter: Redirecting View Trajectory for Monocular Videos via Diffusion Models
")
+ # #
")
+
+
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ # # step 1: input an image
+ # gr.Markdown("---\n## Step 1: Input an Image, selet an elevation angle and a center_scale factor", show_label=False, visible=True)
+ # gr.Markdown("1. Estimate an elevation angle that represents the angle at which the image was taken; a value bigger than 0 indicates a top-down view, and it doesn't need to be precise.
2. The origin of the world coordinate system is by default defined at the point cloud corresponding to the center pixel of the input image. You can adjust the position of the origin by modifying center_scale; a value smaller than 1 brings the origin closer to you.
")
+ i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
+
+
+ with gr.Column():
+ i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
+ show_share_button=True)
+
+ with gr.Row():
+ with gr.Row():
+ i2v_stride = gr.Slider(minimum=1, maximum=3, step=1, elem_id="stride", label="Stride", value=1)
+ i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale",
+ label="center_scale", value=1)
+ i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps",
+ value=50)
+ i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
+ with gr.Row():
+ pan_left = gr.Button(value="Pan Left")
+ pan_right = gr.Button(value="Pan Right")
+ pan_up = gr.Button(value="Pan Up")
+ pan_down = gr.Button(value="Pan Down")
+ with gr.Row():
+ orbit_left = gr.Button(value="Orbit Left")
+ orbit_right = gr.Button(value="Orbit Right")
+ orbit_up = gr.Button(value="Orbit Up")
+ orbit_down = gr.Button(value="Orbit Down")
+ with gr.Row():
+ zin = gr.Button(value="Zoom in")
+ zout = gr.Button(value="Zoom out")
+ custom = gr.Button(value="Customize")
+ reset = gr.Button(value="Reset")
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ i2v_pose = gr.Text(value='0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",
+ visible=False)
+ with gr.Column(visible=False) as i2v_egs:
+ gr.Markdown(
+ "Please refer to
tutorial for customizing camera trajectory.
")
+ gr.Examples(examples=traj_examples,
+ inputs=[i2v_pose],
+ )
+ with gr.Column():
+ i2v_end_btn = gr.Button("Generate video", scale=2, size="lg", variant="primary", elem_classes="generate-btn")
+
+
+ # with gr.Column():
+ # i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
+ # i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
+ # with gr.Row():
+ # # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
+ # i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
+ # i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
+ # i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
+ # with gr.Column():
+ # with gr.Row():
+ # left = gr.Button(value = "Left")
+ # right = gr.Button(value = "Right")
+ # up = gr.Button(value = "Up")
+ # with gr.Row():
+ # down = gr.Button(value = "Down")
+ # zin = gr.Button(value = "Zoom in")
+ # zout = gr.Button(value = "Zoom out")
+ # with gr.Row():
+ # custom = gr.Button(value = "Customize")
+ # reset = gr.Button(value = "Reset")
+
+
+ # step 3 - Generate video
+ # with gr.Column():
+ # gr.Markdown("---\n## Step 3: Generate video", show_label=False, visible=True)
+ # gr.Markdown(" You can reduce the sampling steps for faster inference; try different random seed if the result is not satisfying.
")
+ # i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
+ # i2v_end_btn = gr.Button("Generate video")
+ # i2v_traj_video = gr.Video(label="Camera Trajectory",elem_id="traj_vid",autoplay=True,show_share_button=True)
+
+ # with gr.Column(scale=1.5):
+ # with gr.Row():
+ # # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
+ # i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
+ # i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
+ # i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
+ # with gr.Row():
+ # pan_left = gr.Button(value = "Pan Left")
+ # pan_right = gr.Button(value = "Pan Right")
+ # pan_up = gr.Button(value = "Pan Up")
+ # pan_down = gr.Button(value = "Pan Down")
+ # with gr.Row():
+ # orbit_left = gr.Button(value = "Orbit Left")
+ # orbit_right = gr.Button(value = "Orbit Right")
+ # orbit_up = gr.Button(value = "Orbit Up")
+ # orbit_down = gr.Button(value = "Orbit Down")
+ # with gr.Row():
+ # zin = gr.Button(value = "Zoom in")
+ # zout = gr.Button(value = "Zoom out")
+ # custom = gr.Button(value = "Customize")
+ # reset = gr.Button(value = "Reset")
+ # with gr.Column():
+ # with gr.Row():
+ # with gr.Column():
+ # i2v_pose = gr.Text(value = '0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",visible=False)
+ # with gr.Column(visible=False) as i2v_egs:
+ # gr.Markdown("Please refer to the
tutorial for customizing camera trajectory.
")
+ # gr.Examples(examples=traj_examples,
+ # inputs=[i2v_pose],
+ # )
+ # with gr.Row():
+ # i2v_end_btn = gr.Button("Generate video")
+ # step 3 - Generate video
+ # with gr.Row():
+ # with gr.Column():
+
+
+
+ i2v_end_btn.click(inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
+ outputs=[i2v_output_video],
+ fn = image2video.run_gradio
+ )
+
+ pan_left.click(inputs=[pan_left],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ pan_right.click(inputs=[pan_right],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ pan_up.click(inputs=[pan_up],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ pan_down.click(inputs=[pan_down],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ orbit_left.click(inputs=[orbit_left],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ orbit_right.click(inputs=[orbit_right],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ orbit_up.click(inputs=[orbit_up],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ orbit_down.click(inputs=[orbit_down],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ zin.click(inputs=[zin],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ zout.click(inputs=[zout],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ custom.click(inputs=[custom],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+ reset.click(inputs=[reset],
+ outputs=[i2v_pose,i2v_egs],
+ fn = show_traj
+ )
+
+
+ gr.Examples(examples=img_examples,
+ # inputs=[i2v_input_video,i2v_stride],
+ inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
+ )
+
+ return trajcrafter_iface
+
+
+trajcrafter_iface = trajcrafter_demo(opts)
+trajcrafter_iface.queue(max_size=10)
+# trajcrafter_iface.launch(server_name=args.server_name, max_threads=10, debug=True)
+trajcrafter_iface.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False, max_threads=10)
+
+
+
diff --git a/demo.py b/demo.py
new file mode 100755
index 0000000000000000000000000000000000000000..12388d6b4c61b053c3bacb6c24d914a9e6466821
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,377 @@
+import gc
+import os
+import torch
+from extern.depthcrafter.infer import DepthCrafterDemo
+# from extern.video_depth_anything.vdademo import VDADemo
+import numpy as np
+import torch
+from transformers import T5EncoderModel
+from omegaconf import OmegaConf
+from PIL import Image
+from models.crosstransformer3d import CrossTransformer3DModel
+from models.autoencoder_magvit import AutoencoderKLCogVideoX
+from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
+from models.utils import *
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+ PNDMScheduler)
+from transformers import AutoProcessor, Blip2ForConditionalGeneration
+
+class TrajCrafter:
+ def __init__(self, opts, gradio=False):
+ self.funwarp = Warper(device=opts.device)
+ # self.depth_estimater = VDADemo(pre_train_path=opts.pre_train_path_vda,device=opts.device)
+ self.depth_estimater = DepthCrafterDemo(unet_path=opts.unet_path,pre_train_path=opts.pre_train_path,cpu_offload=opts.cpu_offload,device=opts.device)
+ self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
+ self.captioner = Blip2ForConditionalGeneration.from_pretrained(opts.blip_path, torch_dtype=torch.float16).to(opts.device)
+ self.setup_diffusion(opts)
+ if gradio:
+ self.opts=opts
+
+ def infer_gradual(self,opts):
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
+ assert frames.shape[0] == opts.video_length
+ pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.video_length)
+ warped_images = []
+ masks = []
+ for i in tqdm(range(opts.video_length)):
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, opts.mask,twice=False)
+ warped_images.append(warped_frame2)
+ masks.append(mask2)
+ cond_video = (torch.cat(warped_images)+1.)/2.
+ cond_masks = torch.cat(masks)
+
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
+ save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
+
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
+ frames_ref = frames[:,:,:10,:,:]
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
+
+ del self.depth_estimater
+ del self.caption_processor
+ del self.captioner
+ gc.collect()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ sample = self.pipeline(
+ prompt,
+ num_frames = opts.video_length,
+ negative_prompt = opts.negative_prompt,
+ height = opts.sample_size[0],
+ width = opts.sample_size[1],
+ generator = generator,
+ guidance_scale = opts.diffusion_guidance_scale,
+ num_inference_steps = opts.diffusion_inference_steps,
+ video = cond_video,
+ mask_video = cond_masks,
+ reference = frames_ref,
+ ).videos
+ save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
+
+ viz = True
+ if viz:
+ tensor_left = frames[0].to(opts.device)
+ tensor_right = sample[0].to(opts.device)
+ interval = torch.ones(3, 49, 384, 30).to(opts.device)
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
+ result_reverse = torch.flip(result, dims=[1])
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
+
+ def infer_direct(self,opts):
+ opts.cut = 20
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
+ assert frames.shape[0] == opts.video_length
+ pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.cut)
+
+ warped_images = []
+ masks = []
+ for i in tqdm(range(opts.video_length)):
+ if i < opts.cut:
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[0:1], None, depths[0:1], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
+ warped_images.append(warped_frame2)
+ masks.append(mask2)
+ else:
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i-opts.cut:i-opts.cut+1], None, depths[i-opts.cut:i-opts.cut+1], pose_s[0:1], pose_t[-1:], K[0:1], None, opts.mask,twice=False)
+ warped_images.append(warped_frame2)
+ masks.append(mask2)
+ cond_video = (torch.cat(warped_images)+1.)/2.
+ cond_masks = torch.cat(masks)
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
+ save_video((frames[:opts.video_length-opts.cut].permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
+ save_video(cond_video[opts.cut:].permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
+ save_video(cond_masks[opts.cut:].repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
+ frames_ref = frames[:,:,:10,:,:]
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
+
+ del self.depth_estimater
+ del self.caption_processor
+ del self.captioner
+ gc.collect()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ sample = self.pipeline(
+ prompt,
+ num_frames = opts.video_length,
+ negative_prompt = opts.negative_prompt,
+ height = opts.sample_size[0],
+ width = opts.sample_size[1],
+ generator = generator,
+ guidance_scale = opts.diffusion_guidance_scale,
+ num_inference_steps = opts.diffusion_inference_steps,
+ video = cond_video,
+ mask_video = cond_masks,
+ reference = frames_ref,
+ ).videos
+ save_video(sample[0].permute(1,2,3,0)[opts.cut:], os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
+
+ viz = True
+ if viz:
+ tensor_left = frames[0][:,:opts.video_length-opts.cut,...].to(opts.device)
+ tensor_right = sample[0][:,opts.cut:,...].to(opts.device)
+ interval = torch.ones(3, opts.video_length-opts.cut, 384, 30).to(opts.device)
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
+ result_reverse = torch.flip(result, dims=[1])
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
+
+ def infer_bullet(self,opts):
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
+
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
+ assert frames.shape[0] == opts.video_length
+ pose_s, pose_t, K = self.get_poses(opts,depths, num_frames = opts.video_length)
+
+ warped_images = []
+ masks = []
+ for i in tqdm(range(opts.video_length)):
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[-1:], None, depths[-1:], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
+ warped_images.append(warped_frame2)
+ masks.append(mask2)
+ cond_video = (torch.cat(warped_images)+1.)/2.
+ cond_masks = torch.cat(masks)
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
+ save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
+ frames_ref = frames[:,:,-10:,:,:]
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
+
+ del self.depth_estimater
+ del self.caption_processor
+ del self.captioner
+ gc.collect()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ sample = self.pipeline(
+ prompt,
+ num_frames = opts.video_length,
+ negative_prompt = opts.negative_prompt,
+ height = opts.sample_size[0],
+ width = opts.sample_size[1],
+ generator = generator,
+ guidance_scale = opts.diffusion_guidance_scale,
+ num_inference_steps = opts.diffusion_inference_steps,
+ video = cond_video,
+ mask_video = cond_masks,
+ reference = frames_ref,
+ ).videos
+ save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
+
+ viz = True
+ if viz:
+ tensor_left = frames[0].to(opts.device)
+ tensor_left_full = torch.cat([tensor_left,tensor_left[:,-1:,:,:].repeat(1,48,1,1)],dim=1)
+ tensor_right = sample[0].to(opts.device)
+ tensor_right_full = torch.cat([tensor_left,tensor_right[:,1:,:,:]],dim=1)
+ interval = torch.ones(3, 49*2-1, 384, 30).to(opts.device)
+ result = torch.cat((tensor_left_full, interval, tensor_right_full), dim=3)
+ result_reverse = torch.flip(result, dims=[1])
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*4)
+
+ def get_caption(self,opts,image):
+ image_array = (image * 255).astype(np.uint8)
+ pil_image = Image.fromarray(image_array)
+ inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(opts.device, torch.float16)
+ generated_ids = self.captioner.generate(**inputs)
+ generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+ return generated_text + opts.refine_prompt
+
+ def get_poses(self,opts,depths,num_frames):
+ radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*opts.radius_scale
+ radius = min(radius, 5)
+ cx = 512. #depths.shape[-1]//2
+ cy = 288. #depths.shape[-2]//2
+ f = 500 #500.
+ K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(opts.device)
+ c2w_init = torch.tensor([[-1., 0., 0., 0.],
+ [ 0., 1., 0., 0.],
+ [ 0., 0., -1., 0.],
+ [ 0., 0., 0., 1.]]).to(opts.device).unsqueeze(0)
+ if opts.camera == 'target':
+ dtheta, dphi, dr, dx, dy = opts.target_pose
+ poses = generate_traj_specified(c2w_init, dtheta, dphi, dr*radius, dx, dy, num_frames, opts.device)
+ elif opts.camera =='traj':
+ with open(opts.traj_txt, 'r') as file:
+ lines = file.readlines()
+ theta = [float(i) for i in lines[0].split()]
+ phi = [float(i) for i in lines[1].split()]
+ r = [float(i)*radius for i in lines[2].split()]
+ poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
+ poses[:,2, 3] = poses[:,2, 3] + radius
+ pose_s = poses[opts.anchor_idx:opts.anchor_idx+1].repeat(num_frames,1,1)
+ pose_t = poses
+ return pose_s, pose_t, K
+
+ def setup_diffusion(self,opts):
+ # transformer = CrossTransformer3DModel.from_pretrained_cus(opts.transformer_path).to(opts.weight_dtype)
+ transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path).to(opts.weight_dtype)
+ # transformer = transformer.to(opts.weight_dtype)
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ opts.model_name,
+ subfolder="vae"
+ ).to(opts.weight_dtype)
+ text_encoder = T5EncoderModel.from_pretrained(
+ opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype
+ )
+ # Get Scheduler
+ Choosen_Scheduler = {
+ "Euler": EulerDiscreteScheduler,
+ "Euler A": EulerAncestralDiscreteScheduler,
+ "DPM++": DPMSolverMultistepScheduler,
+ "PNDM": PNDMScheduler,
+ "DDIM_Cog": CogVideoXDDIMScheduler,
+ "DDIM_Origin": DDIMScheduler,
+ }[opts.sampler_name]
+ scheduler = Choosen_Scheduler.from_pretrained(
+ opts.model_name,
+ subfolder="scheduler"
+ )
+
+ self.pipeline = TrajCrafter_Pipeline.from_pretrained(
+ opts.model_name,
+ vae=vae,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ torch_dtype=opts.weight_dtype
+ )
+
+ if opts.low_gpu_memory_mode:
+ self.pipeline.enable_sequential_cpu_offload()
+ else:
+ self.pipeline.enable_model_cpu_offload()
+
+ def run_gradio(self,input_video, stride, radius_scale, pose, steps, seed):
+ frames = read_video_frames(input_video, self.opts.video_length, stride,self.opts.max_res)
+ prompt = self.get_caption(self.opts,frames[self.opts.video_length//2])
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
+ depths= self.depth_estimater.infer(frames, self.opts.near, self.opts.far, self.opts.depth_inference_steps, self.opts.depth_guidance_scale, window_size=self.opts.window_size, overlap=self.opts.overlap).to(self.opts.device)
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(self.opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
+ num_frames = frames.shape[0]
+ assert num_frames == self.opts.video_length
+ radius_scale = float(radius_scale)
+ radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*radius_scale
+ radius = min(radius, 5)
+ cx = 512. #depths.shape[-1]//2
+ cy = 288. #depths.shape[-2]//2
+ f = 500 #500.
+ K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(self.opts.device)
+ c2w_init = torch.tensor([[-1., 0., 0., 0.],
+ [ 0., 1., 0., 0.],
+ [ 0., 0., -1., 0.],
+ [ 0., 0., 0., 1.]]).to(self.opts.device).unsqueeze(0)
+
+ # import pdb
+ # pdb.set_trace()
+ theta,phi,r,x,y = [float(i) for i in pose.split(';')]
+ # theta,phi,r,x,y = [float(i) for i in theta.split()],[float(i) for i in phi.split()],[float(i) for i in r.split()],[float(i) for i in x.split()],[float(i) for i in y.split()]
+ # target mode
+ poses = generate_traj_specified(c2w_init, theta, phi, r*radius, x, y, num_frames, self.opts.device)
+ poses[:,2, 3] = poses[:,2, 3] + radius
+ pose_s = poses[self.opts.anchor_idx:self.opts.anchor_idx+1].repeat(num_frames,1,1)
+ pose_t = poses
+
+ warped_images = []
+ masks = []
+ for i in tqdm(range(self.opts.video_length)):
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, self.opts.mask,twice=False)
+ warped_images.append(warped_frame2)
+ masks.append(mask2)
+ cond_video = (torch.cat(warped_images)+1.)/2.
+ cond_masks = torch.cat(masks)
+
+ frames = F.interpolate(frames, size=self.opts.sample_size, mode='bilinear', align_corners=False)
+ cond_video = F.interpolate(cond_video, size=self.opts.sample_size, mode='bilinear', align_corners=False)
+ cond_masks = F.interpolate(cond_masks, size=self.opts.sample_size, mode='nearest')
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(self.opts.save_dir,'input.mp4'),fps=self.opts.fps)
+ save_video(cond_video.permute(0,2,3,1), os.path.join(self.opts.save_dir,'render.mp4'),fps=self.opts.fps)
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(self.opts.save_dir,'mask.mp4'),fps=self.opts.fps)
+
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
+ frames_ref = frames[:,:,:10,:,:]
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
+ generator = torch.Generator(device=self.opts.device).manual_seed(seed)
+
+ # del self.depth_estimater
+ # del self.caption_processor
+ # del self.captioner
+ # gc.collect()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ sample = self.pipeline(
+ prompt,
+ num_frames = self.opts.video_length,
+ negative_prompt = self.opts.negative_prompt,
+ height = self.opts.sample_size[0],
+ width = self.opts.sample_size[1],
+ generator = generator,
+ guidance_scale = self.opts.diffusion_guidance_scale,
+ num_inference_steps = steps,
+ video = cond_video,
+ mask_video = cond_masks,
+ reference = frames_ref,
+ ).videos
+ save_video(sample[0].permute(1,2,3,0), os.path.join(self.opts.save_dir,'gen.mp4'), fps=self.opts.fps)
+
+ viz = True
+ if viz:
+ tensor_left = frames[0].to(self.opts.device)
+ tensor_right = sample[0].to(self.opts.device)
+ interval = torch.ones(3, 49, 384, 30).to(self.opts.device)
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
+ result_reverse = torch.flip(result, dims=[1])
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
+ save_video(final_result.permute(1,2,3,0), os.path.join(self.opts.save_dir,'viz.mp4'), fps=self.opts.fps*2)
+ return os.path.join(self.opts.save_dir,'viz.mp4')
\ No newline at end of file
diff --git a/docs/config_help.md b/docs/config_help.md
new file mode 100755
index 0000000000000000000000000000000000000000..d56f2e267939a650ff7d8a57b0dce466effbb32f
--- /dev/null
+++ b/docs/config_help.md
@@ -0,0 +1,27 @@
+## Important configuration for [inference.py](../inference.py):
+
+### 1. General configs
+| Configuration | Default Value | Explanation |
+|:----------------- |:--------------- |:-------------------------------------------------------- |
+| `--video_path` | `None` | Input video file path |
+| `--out_dir` | `./experiments/`| Output directory |
+| `--device` | `cuda:0` | The device to use (e.g., CPU or GPU) |
+| `--exp_name` | `None` | Experiment name, defaults to video file name |
+| `--seed` | `43` | Random seed for reproducibility |
+| `--video_length` | `49` | Length of the video frames (number of frames) |
+| `--fps` | `10` | fps for saved video |
+| `--stride` | `1` | Sampling stride for input video (frame interval) |
+| `--server_name` | `None` | Server IP address for gradio |
+### 2. Point cloud render configs
+
+| Configuration | Default Value | Explanation |
+|:----------------- |:--------------- |:-------------------------------------------------------- |
+| `--radius_scale` | `1.0` | Scale factor for the spherical radius |
+| `--camera` | `traj` | Camera pose type, either 'traj' or 'target' |
+| `--mode` | `gradual` | Mode of operation, 'gradual', 'bullet', or 'direct' |
+| `--mask` | `False` | Clean the point cloud data if true |
+| `--target_pose` | `None` | Required for 'target' camera pose type, specifies a relative camera pose sequece (theta, phi, r, x, y). +theta (theta<50) rotates camera upward, +phi (phi<50) rotates camera to right, +r (r<0.6) moves camera forward, +x (x<4) pans the camera to right, +y (y<4) pans the camera upward |
+| `--traj_txt` | `None` | Required for 'traj' camera pose type, a txt file specifying a complex camera trajectory ([examples](../test/trajs/loop1.txt)). The fist line is the theta sequence, the second line the phi sequence, and the last line the r sequence |
+| `--near` | `0.0001` | Near clipping plane distance |
+| `--far` | `10000.0` | Far clipping plane distance |
+| `--anchor_idx` | `0` | One GT frame for anchor frame |
diff --git a/extern/depthcrafter/__init__.py b/extern/depthcrafter/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/extern/depthcrafter/__pycache__/__init__.cpython-310.pyc b/extern/depthcrafter/__pycache__/__init__.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..75804d9a56d0c510bcbd06ec7b39c4803aed7835
Binary files /dev/null and b/extern/depthcrafter/__pycache__/__init__.cpython-310.pyc differ
diff --git a/extern/depthcrafter/__pycache__/demo.cpython-310.pyc b/extern/depthcrafter/__pycache__/demo.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..9620f9998b3b91a193e2be671d146cc827754802
Binary files /dev/null and b/extern/depthcrafter/__pycache__/demo.cpython-310.pyc differ
diff --git a/extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc b/extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..6358fa135a080f6b11439704f6bfade8a38851a8
Binary files /dev/null and b/extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc differ
diff --git a/extern/depthcrafter/__pycache__/infer.cpython-310.pyc b/extern/depthcrafter/__pycache__/infer.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..52ca79fc0f6028dd46596622c15e21ed61555367
Binary files /dev/null and b/extern/depthcrafter/__pycache__/infer.cpython-310.pyc differ
diff --git a/extern/depthcrafter/__pycache__/unet.cpython-310.pyc b/extern/depthcrafter/__pycache__/unet.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a6bdc0462e050eeee57141f3717ea76ada763748
Binary files /dev/null and b/extern/depthcrafter/__pycache__/unet.cpython-310.pyc differ
diff --git a/extern/depthcrafter/__pycache__/utils.cpython-310.pyc b/extern/depthcrafter/__pycache__/utils.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..bb97efa46b999c04d900e1a6d041c00a1fc3ed73
Binary files /dev/null and b/extern/depthcrafter/__pycache__/utils.cpython-310.pyc differ
diff --git a/extern/depthcrafter/depth_crafter_ppl.py b/extern/depthcrafter/depth_crafter_ppl.py
new file mode 100755
index 0000000000000000000000000000000000000000..b7d070d496aec9d1217aac83625878f0159a4ca2
--- /dev/null
+++ b/extern/depthcrafter/depth_crafter_ppl.py
@@ -0,0 +1,366 @@
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
+ _resize_with_antialiasing,
+ StableVideoDiffusionPipelineOutput,
+ StableVideoDiffusionPipeline,
+ retrieve_timesteps,
+)
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class DepthCrafterPipeline(StableVideoDiffusionPipeline):
+
+ @torch.inference_mode()
+ def encode_video(
+ self,
+ video: torch.Tensor,
+ chunk_size: int = 14,
+ ) -> torch.Tensor:
+ """
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
+ :param chunk_size: the chunk size to encode video
+ :return: image_embeddings in shape of [b, 1024]
+ """
+
+ video_224 = _resize_with_antialiasing(video.float(), (224, 224))
+ video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
+
+ embeddings = []
+ for i in range(0, video_224.shape[0], chunk_size):
+ tmp = self.feature_extractor(
+ images=video_224[i : i + chunk_size],
+ do_normalize=True,
+ do_center_crop=False,
+ do_resize=False,
+ do_rescale=False,
+ return_tensors="pt",
+ ).pixel_values.to(video.device, dtype=video.dtype)
+ embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
+
+ embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
+ return embeddings
+
+ @torch.inference_mode()
+ def encode_vae_video(
+ self,
+ video: torch.Tensor,
+ chunk_size: int = 14,
+ ):
+ """
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
+ :param chunk_size: the chunk size to encode video
+ :return: vae latents in shape of [b, c, h, w]
+ """
+ video_latents = []
+ for i in range(0, video.shape[0], chunk_size):
+ video_latents.append(
+ self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
+ )
+ video_latents = torch.cat(video_latents, dim=0)
+ return video_latents
+
+ @staticmethod
+ def check_inputs(video, height, width):
+ """
+ :param video:
+ :param height:
+ :param width:
+ :return:
+ """
+ if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
+ raise ValueError(
+ f"Expected `video` to be a `torch.Tensor` or `VideoReader`, but got a {type(video)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: Union[np.ndarray, torch.Tensor],
+ height: int = 576,
+ width: int = 1024,
+ num_inference_steps: int = 25,
+ guidance_scale: float = 1.0,
+ window_size: Optional[int] = 110,
+ noise_aug_strength: float = 0.02,
+ decode_chunk_size: Optional[int] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ overlap: int = 25,
+ track_time: bool = False,
+ ):
+ """
+ :param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
+ :param height:
+ :param width:
+ :param num_inference_steps:
+ :param guidance_scale:
+ :param window_size: sliding window processing size
+ :param fps:
+ :param motion_bucket_id:
+ :param noise_aug_strength:
+ :param decode_chunk_size:
+ :param generator:
+ :param latents:
+ :param output_type:
+ :param callback_on_step_end:
+ :param callback_on_step_end_tensor_inputs:
+ :param return_dict:
+ :return:
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ num_frames = video.shape[0]
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
+ if num_frames <= window_size:
+ window_size = num_frames
+ overlap = 0
+ stride = window_size - overlap
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(video, height, width)
+
+ # 2. Define call parameters
+ batch_size = 1
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ self._guidance_scale = guidance_scale
+
+ # 3. Encode input video
+ if isinstance(video, np.ndarray):
+ video = torch.from_numpy(video.transpose(0, 3, 1, 2))
+ else:
+ assert isinstance(video, torch.Tensor)
+ video = video.to(device=device, dtype=self.dtype)
+ video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
+
+ if track_time:
+ start_event = torch.cuda.Event(enable_timing=True)
+ encode_event = torch.cuda.Event(enable_timing=True)
+ denoise_event = torch.cuda.Event(enable_timing=True)
+ decode_event = torch.cuda.Event(enable_timing=True)
+ start_event.record()
+
+ video_embeddings = self.encode_video(
+ video, chunk_size=decode_chunk_size
+ ).unsqueeze(
+ 0
+ ) # [1, t, 1024]
+ torch.cuda.empty_cache()
+ # 4. Encode input image using VAE
+ noise = randn_tensor(
+ video.shape, generator=generator, device=device, dtype=video.dtype
+ )
+ video = video + noise_aug_strength * noise # in [t, c, h, w]
+
+ # pdb.set_trace()
+ needs_upcasting = (
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ )
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ video_latents = self.encode_vae_video(
+ video.to(self.vae.dtype),
+ chunk_size=decode_chunk_size,
+ ).unsqueeze(
+ 0
+ ) # [1, t, c, h, w]
+
+ if track_time:
+ encode_event.record()
+ torch.cuda.synchronize()
+ elapsed_time_ms = start_event.elapsed_time(encode_event)
+ print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
+
+ torch.cuda.empty_cache()
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ 7,
+ 127,
+ noise_aug_strength,
+ video_embeddings.dtype,
+ batch_size,
+ 1,
+ False,
+ ) # [1 or 2, 3]
+ added_time_ids = added_time_ids.to(device)
+
+ # 6. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, None, None
+ )
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents_init = self.prepare_latents(
+ batch_size,
+ window_size,
+ num_channels_latents,
+ height,
+ width,
+ video_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ ) # [1, t, c, h, w]
+ latents_all = None
+
+ idx_start = 0
+ if overlap > 0:
+ weights = torch.linspace(0, 1, overlap, device=device)
+ weights = weights.view(1, overlap, 1, 1, 1)
+ else:
+ weights = None
+
+ torch.cuda.empty_cache()
+
+ # inference strategy for long videos
+ # two main strategies: 1. noise init from previous frame, 2. segments stitching
+ while idx_start < num_frames - overlap:
+ idx_end = min(idx_start + window_size, num_frames)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ # 9. Denoising loop
+ latents = latents_init[:, : idx_end - idx_start].clone()
+ latents_init = torch.cat(
+ [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
+ )
+
+ video_latents_current = video_latents[:, idx_start:idx_end]
+ video_embeddings_current = video_embeddings[:, idx_start:idx_end]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if latents_all is not None and i == 0:
+ latents[:, :overlap] = (
+ latents_all[:, -overlap:]
+ + latents[:, :overlap]
+ / self.scheduler.init_noise_sigma
+ * self.scheduler.sigmas[i]
+ )
+
+ latent_model_input = latents # [1, t, c, h, w]
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ ) # [1, t, c, h, w]
+ latent_model_input = torch.cat(
+ [latent_model_input, video_latents_current], dim=2
+ )
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=video_embeddings_current,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ latent_model_input = latents
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+ latent_model_input = torch.cat(
+ [latent_model_input, torch.zeros_like(latent_model_input)],
+ dim=2,
+ )
+ noise_pred_uncond = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=torch.zeros_like(
+ video_embeddings_current
+ ),
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
+ noise_pred - noise_pred_uncond
+ )
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(
+ self, i, t, callback_kwargs
+ )
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps
+ and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if latents_all is None:
+ latents_all = latents.clone()
+ else:
+ assert weights is not None
+ # latents_all[:, -overlap:] = (
+ # latents[:, :overlap] + latents_all[:, -overlap:]
+ # ) / 2.0
+ latents_all[:, -overlap:] = latents[
+ :, :overlap
+ ] * weights + latents_all[:, -overlap:] * (1 - weights)
+ latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
+
+ idx_start += stride
+
+ if track_time:
+ denoise_event.record()
+ torch.cuda.synchronize()
+ elapsed_time_ms = encode_event.elapsed_time(denoise_event)
+ print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
+
+ if track_time:
+ decode_event.record()
+ torch.cuda.synchronize()
+ elapsed_time_ms = denoise_event.elapsed_time(decode_event)
+ print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
+
+ frames = self.video_processor.postprocess_video(
+ video=frames, output_type=output_type
+ )
+ else:
+ frames = latents_all
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames
+
+ return StableVideoDiffusionPipelineOutput(frames=frames)
diff --git a/extern/depthcrafter/infer.py b/extern/depthcrafter/infer.py
new file mode 100755
index 0000000000000000000000000000000000000000..d939c3b9097afe4356f31cac20fe5542dca3d923
--- /dev/null
+++ b/extern/depthcrafter/infer.py
@@ -0,0 +1,91 @@
+import gc
+import os
+import numpy as np
+import torch
+
+from diffusers.training_utils import set_seed
+from extern.depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
+from extern.depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
+
+
+class DepthCrafterDemo:
+ def __init__(
+ self,
+ unet_path: str,
+ pre_train_path: str,
+ cpu_offload: str = "model",
+ device: str = "cuda:0"
+ ):
+ unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
+ unet_path,
+ low_cpu_mem_usage=True,
+ torch_dtype=torch.float16,
+ )
+ # load weights of other components from the provided checkpoint
+ self.pipe = DepthCrafterPipeline.from_pretrained(
+ pre_train_path,
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant="fp16",
+ )
+
+ # for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
+ if cpu_offload is not None:
+ if cpu_offload == "sequential":
+ # This will slow, but save more memory
+ self.pipe.enable_sequential_cpu_offload()
+ elif cpu_offload == "model":
+ self.pipe.enable_model_cpu_offload()
+ else:
+ raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
+ else:
+ self.pipe.to(device)
+ # enable attention slicing and xformers memory efficient attention
+ try:
+ self.pipe.enable_xformers_memory_efficient_attention()
+ except Exception as e:
+ print(e)
+ print("Xformers is not enabled")
+ self.pipe.enable_attention_slicing()
+
+ def infer(
+ self,
+ frames,
+ near,
+ far,
+ num_denoising_steps: int,
+ guidance_scale: float,
+ window_size: int = 110,
+ overlap: int = 25,
+ seed: int = 42,
+ track_time: bool = True,
+ ):
+ set_seed(seed)
+
+ # inference the depth map using the DepthCrafter pipeline
+ with torch.inference_mode():
+ res = self.pipe(
+ frames,
+ height=frames.shape[1],
+ width=frames.shape[2],
+ output_type="np",
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_denoising_steps,
+ window_size=window_size,
+ overlap=overlap,
+ track_time=track_time,
+ ).frames[0]
+ # convert the three-channel output to a single channel depth map
+ res = res.sum(-1) / res.shape[-1]
+ # normalize the depth map to [0, 1] across the whole video
+ depths = (res - res.min()) / (res.max() - res.min())
+ # visualize the depth map and save the results
+ # vis = vis_sequence_depth(res)
+ # save the depth map and visualization with the target FPS
+ depths = torch.from_numpy(depths).unsqueeze(1) # 49 576 1024 ->
+ depths *= 3900 # compatible with da output
+ depths[depths < 1e-5] = 1e-5
+ depths = 10000. / depths
+ depths = depths.clip(near, far)
+
+ return depths
\ No newline at end of file
diff --git a/extern/depthcrafter/unet.py b/extern/depthcrafter/unet.py
new file mode 100755
index 0000000000000000000000000000000000000000..0066a71c7a054d2e729f45baacc3a223276c1f44
--- /dev/null
+++ b/extern/depthcrafter/unet.py
@@ -0,0 +1,142 @@
+from typing import Union, Tuple
+
+import torch
+from diffusers import UNetSpatioTemporalConditionModel
+from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
+
+
+class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
+ UNetSpatioTemporalConditionModel
+):
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
+
+ emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
+
+ # 2. pre-process
+ sample = sample.to(dtype=self.conv_in.weight.dtype)
+ assert sample.dtype == self.conv_in.weight.dtype, (
+ f"sample.dtype: {sample.dtype}, "
+ f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
+ )
+ sample = self.conv_in(sample)
+
+ image_only_indicator = torch.zeros(
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
+ )
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if (
+ hasattr(downsample_block, "has_cross_attention")
+ and downsample_block.has_cross_attention
+ ):
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[
+ : -len(upsample_block.resnets)
+ ]
+
+ if (
+ hasattr(upsample_block, "has_cross_attention")
+ and upsample_block.has_cross_attention
+ ):
+ sample = upsample_block(
+ hidden_states=sample,
+ res_hidden_states_tuple=res_samples,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ res_hidden_states_tuple=res_samples,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diff --git a/extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc b/extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..58680c97a2cda5b68b101bf1490b4607561b77c2
Binary files /dev/null and b/extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc b/extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c1fb8ba06dd09f083c8e3abde22da8102b058d16
Binary files /dev/null and b/extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc b/extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4b2634a1549fabe7ce9e7ebfb1f581dc2d45f7d2
Binary files /dev/null and b/extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc b/extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..eb26da54605334de6515e57b3e670bf723432d35
Binary files /dev/null and b/extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc b/extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a84584c55418939c5a7e695b3eed974c3a6b34af
Binary files /dev/null and b/extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2.py b/extern/video_depth_anything/dinov2.py
new file mode 100755
index 0000000000000000000000000000000000000000..83d250818c721c6df3b30d3f4352945527701615
--- /dev/null
+++ b/extern/video_depth_anything/dinov2.py
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+ # w0, h0 = w0 + 0.1, h0 + 0.1
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
+ mode="bicubic",
+ antialias=self.interpolate_antialias
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def DINOv2(model_name):
+ model_zoo = {
+ "vits": vit_small,
+ "vitb": vit_base,
+ "vitl": vit_large,
+ "vitg": vit_giant2
+ }
+
+ return model_zoo[model_name](
+ img_size=518,
+ patch_size=14,
+ init_values=1.0,
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+ block_chunks=0,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1
+ )
diff --git a/extern/video_depth_anything/dinov2_layers/__init__.py b/extern/video_depth_anything/dinov2_layers/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..daa372e3bda3e94d05b53b959b2a413fbf195376
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..7f878f785caec333e3df495edd5ba2d3a42514f7
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..dc5c75c31c06b168fb3b2c19eff20991515ebf27
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..6259f0ddd4e4f8480a3b0654bcc3f5235e79635c
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..5aa1912c8c95f62c631bcff468b58faa356be998
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..dfa5ee1ff9abd52bb52bf9c30f13c2854cb4e1f9
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..bafef74bc3d2a961382606792dc8f01c9e40683e
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc b/extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..d493ec1db8ab82b04c3a87daab5cc327b3e33921
Binary files /dev/null and b/extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/dinov2_layers/attention.py b/extern/video_depth_anything/dinov2_layers/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
\ No newline at end of file
diff --git a/extern/video_depth_anything/dinov2_layers/block.py b/extern/video_depth_anything/dinov2_layers/block.py
new file mode 100755
index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/block.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/extern/video_depth_anything/dinov2_layers/drop_path.py b/extern/video_depth_anything/dinov2_layers/drop_path.py
new file mode 100755
index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/drop_path.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/extern/video_depth_anything/dinov2_layers/layer_scale.py b/extern/video_depth_anything/dinov2_layers/layer_scale.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/extern/video_depth_anything/dinov2_layers/mlp.py b/extern/video_depth_anything/dinov2_layers/mlp.py
new file mode 100755
index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/extern/video_depth_anything/dinov2_layers/patch_embed.py b/extern/video_depth_anything/dinov2_layers/patch_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/patch_embed.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/extern/video_depth_anything/dinov2_layers/swiglu_ffn.py b/extern/video_depth_anything/dinov2_layers/swiglu_ffn.py
new file mode 100755
index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e
--- /dev/null
+++ b/extern/video_depth_anything/dinov2_layers/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/extern/video_depth_anything/dpt.py b/extern/video_depth_anything/dpt.py
new file mode 100755
index 0000000000000000000000000000000000000000..8c43a1447799cc3550e7e5bba71577f7c5aec9e6
--- /dev/null
+++ b/extern/video_depth_anything/dpt.py
@@ -0,0 +1,160 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .util.blocks import FeatureFusionBlock, _make_scratch
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+
+class DPTHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False
+ ):
+ super(DPTHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True),
+ nn.Identity(),
+ )
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+ out = self.scratch.output_conv2(out)
+
+ return out
+
\ No newline at end of file
diff --git a/extern/video_depth_anything/dpt_temporal.py b/extern/video_depth_anything/dpt_temporal.py
new file mode 100755
index 0000000000000000000000000000000000000000..d9a27a64ec826cbff174abd3a22e6816039b2f34
--- /dev/null
+++ b/extern/video_depth_anything/dpt_temporal.py
@@ -0,0 +1,96 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from .dpt import DPTHead
+from .motion_module.motion_module import TemporalModule
+from easydict import EasyDict
+
+
+class DPTHeadTemporal(DPTHead):
+ def __init__(self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False,
+ num_frames=32,
+ pe='ape'
+ ):
+ super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
+
+ assert num_frames > 0
+ motion_module_kwargs = EasyDict(num_attention_heads = 8,
+ num_transformer_block = 1,
+ num_attention_blocks = 2,
+ temporal_max_len = num_frames,
+ zero_initialize = True,
+ pos_embedding_type = pe)
+
+ self.motion_modules = nn.ModuleList([
+ TemporalModule(in_channels=out_channels[2],
+ **motion_module_kwargs),
+ TemporalModule(in_channels=out_channels[3],
+ **motion_module_kwargs),
+ TemporalModule(in_channels=features,
+ **motion_module_kwargs),
+ TemporalModule(in_channels=features,
+ **motion_module_kwargs)
+ ])
+
+ def forward(self, out_features, patch_h, patch_w, frame_length):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
+
+ B, T = x.shape[0] // frame_length, frame_length
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ B, T = layer_1.shape[0] // frame_length, frame_length
+
+ layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
+ layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+ out = F.interpolate(
+ out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
+ )
+ out = self.scratch.output_conv2(out)
+
+ return out
\ No newline at end of file
diff --git a/extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc b/extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..84860e3d7b4d2fd37d13d3f0664d635d655939a8
Binary files /dev/null and b/extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc b/extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..45294367830aa5e50b27aeb92466aea0b928dfb1
Binary files /dev/null and b/extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/motion_module/attention.py b/extern/video_depth_anything/motion_module/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..41f551ba17e95c008f3ad1db5475c342855dd2a4
--- /dev/null
+++ b/extern/video_depth_anything/motion_module/attention.py
@@ -0,0 +1,429 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ print("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.upcast_efficient_attention = False
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
+ return tensor
+
+ def reshape_heads_to_4d(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
+ return tensor
+
+ def reshape_4d_to_heads(self, tensor):
+ batch_size, seq_len, head_size, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ if self.upcast_efficient_attention:
+ org_dtype = query.dtype
+ query = query.float()
+ key = key.float()
+ value = value.float()
+ if attention_mask is not None:
+ attention_mask = attention_mask.float()
+ hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
+
+ if self.upcast_efficient_attention:
+ hidden_states = hidden_states.to(org_dtype)
+
+ hidden_states = self.reshape_4d_to_heads(hidden_states)
+ return hidden_states
+
+ # print("Errror: no xformers")
+ # raise NotImplementedError
+
+ def _memory_efficient_attention_split(self, query, key, value, attention_mask):
+ batch_size = query.shape[0]
+ max_batch_size = 65535
+ num_batches = (batch_size + max_batch_size - 1) // max_batch_size
+ results = []
+ for i in range(num_batches):
+ start_idx = i * max_batch_size
+ end_idx = min((i + 1) * max_batch_size, batch_size)
+ query_batch = query[start_idx:end_idx]
+ key_batch = key[start_idx:end_idx]
+ value_batch = value[start_idx:end_idx]
+ if attention_mask is not None:
+ attention_mask_batch = attention_mask[start_idx:end_idx]
+ else:
+ attention_mask_batch = None
+ result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
+ results.append(result)
+ full_result = torch.cat(results, dim=0)
+ return full_result
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ """
+ The approximate form of Gaussian Error Linear Unit (GELU)
+
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
+ return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
diff --git a/extern/video_depth_anything/motion_module/motion_module.py b/extern/video_depth_anything/motion_module/motion_module.py
new file mode 100755
index 0000000000000000000000000000000000000000..bbb19e225dee71023f8b5f05954cf22126a23841
--- /dev/null
+++ b/extern/video_depth_anything/motion_module/motion_module.py
@@ -0,0 +1,297 @@
+# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
+# SPDX-License-Identifier: Apache-2.0 license
+#
+# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
+# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
+
+from einops import rearrange, repeat
+import math
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ print("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class TemporalModule(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads = 8,
+ num_transformer_block = 2,
+ num_attention_blocks = 2,
+ norm_num_groups = 32,
+ temporal_max_len = 32,
+ zero_initialize = True,
+ pos_embedding_type = "ape",
+ ):
+ super().__init__()
+
+ self.temporal_transformer = TemporalTransformer3DModel(
+ in_channels=in_channels,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads,
+ num_layers=num_transformer_block,
+ num_attention_blocks=num_attention_blocks,
+ norm_num_groups=norm_num_groups,
+ temporal_max_len=temporal_max_len,
+ pos_embedding_type=pos_embedding_type,
+ )
+
+ if zero_initialize:
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+
+ def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
+ hidden_states = input_tensor
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+
+ output = hidden_states
+ return output
+
+
+class TemporalTransformer3DModel(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads,
+ attention_head_dim,
+ num_layers,
+ num_attention_blocks = 2,
+ norm_num_groups = 32,
+ temporal_max_len = 32,
+ pos_embedding_type = "ape",
+ ):
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_attention_blocks=num_attention_blocks,
+ temporal_max_len=temporal_max_len,
+ pos_embedding_type=pos_embedding_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+ batch, channel, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
+ hidden_states = self.proj_in(hidden_states)
+
+ # Transformer Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
+
+ # output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+
+ return output
+
+
+class TemporalTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ num_attention_blocks = 2,
+ temporal_max_len = 32,
+ pos_embedding_type = "ape",
+ ):
+ super().__init__()
+
+ self.attention_blocks = nn.ModuleList(
+ [
+ TemporalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ temporal_max_len=temporal_max_len,
+ pos_embedding_type=pos_embedding_type,
+ )
+ for i in range(num_attention_blocks)
+ ]
+ )
+ self.norms = nn.ModuleList(
+ [
+ nn.LayerNorm(dim)
+ for i in range(num_attention_blocks)
+ ]
+ )
+
+ self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
+ self.ff_norm = nn.LayerNorm(dim)
+
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
+ norm_hidden_states = norm(hidden_states)
+ hidden_states = attention_block(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ video_length=video_length,
+ attention_mask=attention_mask,
+ ) + hidden_states
+
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
+
+ output = hidden_states
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout = 0.,
+ max_len = 32
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)].to(x.dtype)
+ return self.dropout(x)
+
+class TemporalAttention(CrossAttention):
+ def __init__(
+ self,
+ temporal_max_len = 32,
+ pos_embedding_type = "ape",
+ *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.pos_embedding_type = pos_embedding_type
+ self._use_memory_efficient_attention_xformers = True
+
+ self.pos_encoder = None
+ self.freqs_cis = None
+ if self.pos_embedding_type == "ape":
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0.,
+ max_len=temporal_max_len
+ )
+
+ elif self.pos_embedding_type == "rope":
+ self.freqs_cis = precompute_freqs_cis(
+ kwargs["query_dim"],
+ temporal_max_len
+ )
+
+ else:
+ raise NotImplementedError
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+ if self.pos_encoder is not None:
+ hidden_states = self.pos_encoder(hidden_states)
+
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if self.freqs_cis is not None:
+ seq_len = query.shape[1]
+ freqs_cis = self.freqs_cis[:seq_len].to(query.device)
+ query, key = apply_rotary_emb(query, key, freqs_cis)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+
+ use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
+ if use_memory_efficient and (dim // self.heads) % 8 != 0:
+ # print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
+ use_memory_efficient = False
+
+ # attention, what we cannot get enough of
+ if use_memory_efficient:
+ query = self.reshape_heads_to_4d(query)
+ key = self.reshape_heads_to_4d(key)
+ value = self.reshape_heads_to_4d(value)
+
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ query = self.reshape_heads_to_batch_dim(query)
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ raise NotImplementedError
+ # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
diff --git a/extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc b/extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..4236c687fa2940c4841519f4507d3c003206e0de
Binary files /dev/null and b/extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc b/extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..431510bef05971c2b6424fea8c7f57705e4ebecd
Binary files /dev/null and b/extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc b/extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..58f016a048c1bddb7ca8c6b1c365924e3bd53886
Binary files /dev/null and b/extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc differ
diff --git a/extern/video_depth_anything/util/blocks.py b/extern/video_depth_anything/util/blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..0be16c053d8ae65bf33c5221351add16ecd6fb15
--- /dev/null
+++ b/extern/video_depth_anything/util/blocks.py
@@ -0,0 +1,162 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn is True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn is True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn is True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups = 1
+
+ self.expand = expand
+ out_features = features
+ if self.expand is True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
+ )
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(
+ output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/extern/video_depth_anything/util/transform.py b/extern/video_depth_anything/util/transform.py
new file mode 100755
index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73
--- /dev/null
+++ b/extern/video_depth_anything/util/transform.py
@@ -0,0 +1,158 @@
+import numpy as np
+import cv2
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
+
+ # resize sample
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
+
+ if self.__resize_target:
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
+
+ if "mask" in sample:
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ return sample
\ No newline at end of file
diff --git a/extern/video_depth_anything/util/util.py b/extern/video_depth_anything/util/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..75ff80a841f2e30cb182a17c7a0e9e76ec062025
--- /dev/null
+++ b/extern/video_depth_anything/util/util.py
@@ -0,0 +1,74 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+def compute_scale_and_shift(prediction, target, mask, scale_only=False):
+ if scale_only:
+ return compute_scale(prediction, target, mask), 0
+ else:
+ return compute_scale_and_shift_full(prediction, target, mask)
+
+
+def compute_scale(prediction, target, mask):
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+ prediction = prediction.astype(np.float32)
+ target = target.astype(np.float32)
+ mask = mask.astype(np.float32)
+
+ a_00 = np.sum(mask * prediction * prediction)
+ a_01 = np.sum(mask * prediction)
+ a_11 = np.sum(mask)
+
+ # right hand side: b = [b_0, b_1]
+ b_0 = np.sum(mask * prediction * target)
+
+ x_0 = b_0 / (a_00 + 1e-6)
+
+ return x_0
+
+def compute_scale_and_shift_full(prediction, target, mask):
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+ prediction = prediction.astype(np.float32)
+ target = target.astype(np.float32)
+ mask = mask.astype(np.float32)
+
+ a_00 = np.sum(mask * prediction * prediction)
+ a_01 = np.sum(mask * prediction)
+ a_11 = np.sum(mask)
+
+ b_0 = np.sum(mask * prediction * target)
+ b_1 = np.sum(mask * target)
+
+ x_0 = 1
+ x_1 = 0
+
+ det = a_00 * a_11 - a_01 * a_01
+
+ if det != 0:
+ x_0 = (a_11 * b_0 - a_01 * b_1) / det
+ x_1 = (-a_01 * b_0 + a_00 * b_1) / det
+
+ return x_0, x_1
+
+
+def get_interpolate_frames(frame_list_pre, frame_list_post):
+ assert len(frame_list_pre) == len(frame_list_post)
+ min_w = 0.0
+ max_w = 1.0
+ step = (max_w - min_w) / (len(frame_list_pre)-1)
+ post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
+ interpolated_frames = []
+ for i in range(len(frame_list_pre)):
+ interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
+ return interpolated_frames
\ No newline at end of file
diff --git a/extern/video_depth_anything/vdademo.py b/extern/video_depth_anything/vdademo.py
new file mode 100755
index 0000000000000000000000000000000000000000..2f7473e6207ea1576c05f3e1a42a6f1f670525c9
--- /dev/null
+++ b/extern/video_depth_anything/vdademo.py
@@ -0,0 +1,63 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import numpy as np
+import os
+import torch
+from extern.video_depth_anything.video_depth import VideoDepthAnything
+
+class VDADemo:
+ def __init__(
+ self,
+ pre_train_path: str,
+ encoder: str = "vitl",
+ device: str = "cuda:0",
+ ):
+
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
+ }
+
+ self.video_depth_anything = VideoDepthAnything(**model_configs[encoder])
+ self.video_depth_anything.load_state_dict(torch.load(pre_train_path, map_location='cpu'), strict=True)
+ self.video_depth_anything = self.video_depth_anything.to(device).eval()
+ self.device = device
+
+ def infer(
+ self,
+ frames,
+ near,
+ far,
+ input_size = 518,
+ target_fps = -1,
+ ):
+ if frames.max() < 2.:
+ frames = frames*255.
+
+ with torch.inference_mode():
+ depths, fps = self.video_depth_anything.infer_video_depth(frames, target_fps, input_size, self.device)
+
+ depths = torch.from_numpy(depths).unsqueeze(1) # 49 576 1024 ->
+ depths[depths < 1e-5] = 1e-5
+ depths = 10000. / depths
+ depths = depths.clip(near, far)
+
+
+ return depths
+
+
+
+
+
diff --git a/extern/video_depth_anything/video_depth.py b/extern/video_depth_anything/video_depth.py
new file mode 100755
index 0000000000000000000000000000000000000000..c4e23efbac9011cb0e9ddcc63a5e40570c942416
--- /dev/null
+++ b/extern/video_depth_anything/video_depth.py
@@ -0,0 +1,154 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torchvision.transforms import Compose
+import cv2
+from tqdm import tqdm
+import numpy as np
+import gc
+
+from extern.video_depth_anything.dinov2 import DINOv2
+from extern.video_depth_anything.dpt_temporal import DPTHeadTemporal
+from extern.video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
+
+from extern.video_depth_anything.util.util import compute_scale_and_shift, get_interpolate_frames
+
+# infer settings, do not change
+INFER_LEN = 32
+OVERLAP = 10
+KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
+INTERP_LEN = 8
+
+class VideoDepthAnything(nn.Module):
+ def __init__(
+ self,
+ encoder='vitl',
+ features=256,
+ out_channels=[256, 512, 1024, 1024],
+ use_bn=False,
+ use_clstoken=False,
+ num_frames=32,
+ pe='ape'
+ ):
+ super(VideoDepthAnything, self).__init__()
+
+ self.intermediate_layer_idx = {
+ 'vits': [2, 5, 8, 11],
+ 'vitl': [4, 11, 17, 23]
+ }
+
+ self.encoder = encoder
+ self.pretrained = DINOv2(model_name=encoder)
+
+ self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
+
+ def forward(self, x):
+ B, T, C, H, W = x.shape
+ patch_h, patch_w = H // 14, W // 14
+ features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
+ depth = self.head(features, patch_h, patch_w, T)
+ depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
+ depth = F.relu(depth)
+ return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
+
+ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
+ frame_height, frame_width = frames[0].shape[:2]
+ ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
+ if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
+ input_size = int(input_size * 1.777 / ratio)
+ input_size = round(input_size / 14) * 14
+
+ transform = Compose([
+ Resize(
+ width=input_size,
+ height=input_size,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method='lower_bound',
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ])
+
+ frame_list = [frames[i] for i in range(frames.shape[0])]
+ frame_step = INFER_LEN - OVERLAP
+ org_video_len = len(frame_list)
+ append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
+ frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
+
+ depth_list = []
+ pre_input = None
+ for frame_id in tqdm(range(0, org_video_len, frame_step)):
+ cur_list = []
+ for i in range(INFER_LEN):
+ cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
+ cur_input = torch.cat(cur_list, dim=1).to(device)
+ if pre_input is not None:
+ cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
+
+ with torch.no_grad():
+ depth = self.forward(cur_input) # depth shape: [1, T, H, W]
+
+ depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
+ depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
+
+ pre_input = cur_input
+
+ del frame_list
+ gc.collect()
+
+ depth_list_aligned = []
+ ref_align = []
+ align_len = OVERLAP - INTERP_LEN
+ kf_align_list = KEYFRAMES[:align_len]
+
+ for frame_id in range(0, len(depth_list), INFER_LEN):
+ if len(depth_list_aligned) == 0:
+ depth_list_aligned += depth_list[:INFER_LEN]
+ for kf_id in kf_align_list:
+ ref_align.append(depth_list[frame_id+kf_id])
+ else:
+ curr_align = []
+ for i in range(len(kf_align_list)):
+ curr_align.append(depth_list[frame_id+i])
+ scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
+ np.concatenate(ref_align),
+ np.concatenate(np.ones_like(ref_align)==1))
+
+ pre_depth_list = depth_list_aligned[-INTERP_LEN:]
+ post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
+ for i in range(len(post_depth_list)):
+ post_depth_list[i] = post_depth_list[i] * scale + shift
+ post_depth_list[i][post_depth_list[i]<0] = 0
+ depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
+
+ for i in range(OVERLAP, INFER_LEN):
+ new_depth = depth_list[frame_id+i] * scale + shift
+ new_depth[new_depth<0] = 0
+ depth_list_aligned.append(new_depth)
+
+ ref_align = ref_align[:1]
+ for kf_id in kf_align_list[1:]:
+ new_depth = depth_list[frame_id+kf_id] * scale + shift
+ new_depth[new_depth<0] = 0
+ ref_align.append(new_depth)
+
+ depth_list = depth_list_aligned
+
+ return np.stack(depth_list[:org_video_len], axis=0), target_fps
+
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..caa61580aed2699ccc9d0ccfa9780ac6fcad2bb1
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,77 @@
+from demo import TrajCrafter
+import os
+from datetime import datetime
+import argparse
+import torch
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+
+ ## general
+ parser.add_argument('--video_path',type=str, help='Input path')
+ parser.add_argument('--out_dir',type=str,default='./experiments/', help='Output dir')
+ parser.add_argument('--device', type=str, default='cuda:0', help='The device to use')
+ parser.add_argument('--exp_name', type=str, default=None, help='Experiment name, use video file name by default')
+ parser.add_argument('--seed', type=int, default=43, help='Random seed for reproducibility')
+ parser.add_argument('--video_length', type=int, default=49, help='Length of the video frames')
+ parser.add_argument('--fps', type=int, default=10, help='Fps for saved video')
+ parser.add_argument('--stride', type=int, default=1, help='Sampling stride for input video')
+ parser.add_argument('--server_name', type=str, help='Server IP address')
+
+ ## render
+ parser.add_argument('--radius_scale',type=float,default=1.0 , help='Scale factor for the spherical radius')
+ parser.add_argument('--camera',type=str,default='traj', help='traj or target' )
+ parser.add_argument('--mode',type=str,default='gradual', help='gradual, bullet or direct' )
+ parser.add_argument('--mask',action='store_true',default=False, help='Clean the pcd if true' )
+ parser.add_argument('--traj_txt', type=str, help="Required for 'traj' camera, a txt file that specify camera trajectory")
+ parser.add_argument('--target_pose',nargs=5,type=float, help="Required for 'target' mode, specify target camera pose, " )
+ parser.add_argument('--near', type=float, default=0.0001, help='Near clipping plane distance')
+ parser.add_argument('--far', type=float, default=10000.0, help='Far clipping plane distance')
+ parser.add_argument('--anchor_idx', type=int, default=0, help='One GT frame')
+
+ ## diffusion
+ parser.add_argument('--low_gpu_memory_mode', type=bool, default=False, help='Enable low GPU memory mode')
+ # parser.add_argument('--model_name', type=str, default='checkpoints/CogVideoX-Fun-V1.1-5b-InP', help='Path to the model')
+ parser.add_argument('--model_name', type=str, default='alibaba-pai/CogVideoX-Fun-V1.1-5b-InP', help='Path to the model')
+ parser.add_argument('--sampler_name', type=str, choices=["Euler", "Euler A", "DPM++", "PNDM", "DDIM_Cog", "DDIM_Origin"], default='DDIM_Origin', help='Choose the sampler')
+ # parser.add_argument('--transformer_path', type=str, default='checkpoints/TrajectoryCrafter/crosstransformer', help='Path to the pretrained transformer model')
+ parser.add_argument('--transformer_path', type=str, default="TrajectoryCrafter/TrajectoryCrafter", help='Path to the pretrained transformer model')
+ parser.add_argument('--sample_size', type=int, nargs=2, default=[384, 672], help='Sample size as [height, width]')
+ parser.add_argument('--diffusion_guidance_scale', type=float, default=6.0, help='Guidance scale for inference')
+ parser.add_argument('--diffusion_inference_steps', type=int, default=50, help='Number of inference steps')
+ parser.add_argument('--prompt', type=str, default=None, help='Prompt for video generation')
+ parser.add_argument('--negative_prompt', type=str, default="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion.", help='Negative prompt for video generation')
+ parser.add_argument('--refine_prompt', type=str, default=". The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", help='Prompt for video generation')
+ parser.add_argument('--blip_path',type=str,default="Salesforce/blip2-opt-2.7b")
+
+ ## depth
+ # parser.add_argument('--unet_path', type=str, default='checkpoints/DepthCrafter', help='Path to the UNet model')
+ parser.add_argument('--unet_path', type=str, default="tencent/DepthCrafter", help='Path to the UNet model')
+ parser.add_argument('--pre_train_path_vda', type=str, default='checkpoints/video_depth_anything_vitl.pth', help='Path to the pre-trained model')
+ # parser.add_argument('--pre_train_path', type=str, default='checkpoints/stable-video-diffusion-img2vid-xt', help='Path to the pre-trained model')
+ parser.add_argument('--pre_train_path', type=str, default="stabilityai/stable-video-diffusion-img2vid-xt", help='Path to the pre-trained model')
+ parser.add_argument('--cpu_offload', type=str, default='model', help='CPU offload strategy')
+ parser.add_argument('--depth_inference_steps', type=int, default=5, help='Number of inference steps')
+ parser.add_argument('--depth_guidance_scale', type=float, default=1.0, help='Guidance scale for inference')
+ parser.add_argument('--window_size', type=int, default=110, help='Window size for processing')
+ parser.add_argument('--overlap', type=int, default=25, help='Overlap size for processing')
+ parser.add_argument('--max_res', type=int, default=1024, help='Maximum resolution for processing')
+
+ return parser
+
+if __name__=="__main__":
+ parser = get_parser() # infer config.py
+ opts = parser.parse_args()
+ opts.weight_dtype = torch.bfloat16
+ if opts.exp_name == None:
+ prefix = datetime.now().strftime("%Y%m%d_%H%M")
+ opts.exp_name = f'{prefix}_{os.path.splitext(os.path.basename(opts.video_path))[0]}'
+ opts.save_dir = os.path.join(opts.out_dir,opts.exp_name)
+ os.makedirs(opts.save_dir,exist_ok=True)
+ pvd = TrajCrafter(opts)
+ if opts.mode == 'gradual':
+ pvd.infer_gradual(opts)
+ elif opts.mode == 'direct':
+ pvd.infer_direct(opts)
+ elif opts.mode == 'bullet':
+ pvd.infer_bullet(opts)
diff --git a/models/__pycache__/autoencoder_magvit.cpython-310.pyc b/models/__pycache__/autoencoder_magvit.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..97e58012e3d4357f50081f957e75b5b407b2fa71
Binary files /dev/null and b/models/__pycache__/autoencoder_magvit.cpython-310.pyc differ
diff --git a/models/__pycache__/crosstransformer3d.cpython-310.pyc b/models/__pycache__/crosstransformer3d.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..241f1ba267a97e8abc904e271d159a81ffef3dcb
Binary files /dev/null and b/models/__pycache__/crosstransformer3d.cpython-310.pyc differ
diff --git a/models/__pycache__/pipeline_trajectorycrafter.cpython-310.pyc b/models/__pycache__/pipeline_trajectorycrafter.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..ee280afe2af000286263444a99530d9c7a965638
Binary files /dev/null and b/models/__pycache__/pipeline_trajectorycrafter.cpython-310.pyc differ
diff --git a/models/__pycache__/pipeline_viewcrafter4d.cpython-310.pyc b/models/__pycache__/pipeline_viewcrafter4d.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..f663f8a2bb263e97091d7a7a374e5cbd237f0109
Binary files /dev/null and b/models/__pycache__/pipeline_viewcrafter4d.cpython-310.pyc differ
diff --git a/models/__pycache__/utils.cpython-310.pyc b/models/__pycache__/utils.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..3f00f1af7267ef7b8ba0efd8a14d51bf055536ba
Binary files /dev/null and b/models/__pycache__/utils.cpython-310.pyc differ
diff --git a/models/__pycache__/utils.cpython-38.pyc b/models/__pycache__/utils.cpython-38.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..780dcc063c528342db6230aeb8b360683fb97a17
Binary files /dev/null and b/models/__pycache__/utils.cpython-38.pyc differ
diff --git a/models/autoencoder_magvit.py b/models/autoencoder_magvit.py
new file mode 100755
index 0000000000000000000000000000000000000000..a1ac2ec5d7aeff3a2bf548eab8b9e32655103c60
--- /dev/null
+++ b/models/autoencoder_magvit.py
@@ -0,0 +1,1296 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import logging
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.activations import get_activation
+from diffusers.models.downsampling import CogVideoXDownsample3D
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.upsampling import CogVideoXUpsample3D
+from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogVideoXSafeConv3d(nn.Conv3d):
+ r"""
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
+ """
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+
+ # Set to 2GB, suitable for CuDNN
+ if memory_count > 2:
+ kernel_size = self.kernel_size[0]
+ part_num = int(memory_count / 2) + 1
+ input_chunks = torch.chunk(input, part_num, dim=2)
+
+ if kernel_size > 1:
+ input_chunks = [input_chunks[0]] + [
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+ for i in range(1, len(input_chunks))
+ ]
+
+ output_chunks = []
+ for input_chunk in input_chunks:
+ output_chunks.append(super().forward(input_chunk))
+ output = torch.cat(output_chunks, dim=2)
+ return output
+ else:
+ return super().forward(input)
+
+
+class CogVideoXCausalConv3d(nn.Module):
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+
+ Args:
+ in_channels (`int`): Number of channels in the input tensor.
+ out_channels (`int`): Number of output channels produced by the convolution.
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+ stride (`int`, defaults to `1`): Stride of the convolution.
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: int = 1,
+ dilation: int = 1,
+ pad_mode: str = "constant",
+ ):
+ super().__init__()
+
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size,) * 3
+
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+ self.pad_mode = pad_mode
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+ self.temporal_dim = 2
+ self.time_kernel_size = time_kernel_size
+
+ stride = (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = CogVideoXSafeConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ )
+
+ self.conv_cache = None
+
+ def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ kernel_size = self.time_kernel_size
+ if kernel_size > 1:
+ cached_inputs = (
+ [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+ )
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
+ return inputs
+
+ def _clear_fake_context_parallel_cache(self):
+ del self.conv_cache
+ self.conv_cache = None
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ inputs = self.fake_context_parallel_forward(inputs)
+
+ self._clear_fake_context_parallel_cache()
+ # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
+ # hundred megabytes and so let's not do it for now
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
+
+ output = self.conv(inputs)
+ return output
+
+
+class CogVideoXSpatialNorm3D(nn.Module):
+ r"""
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
+ to 3D-video like data.
+
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ groups (`int`):
+ Number of groups to separate the channels into for group normalization.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ groups: int = 32,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
+ z_first = F.interpolate(z_first, size=f_first_size)
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
+ zq = torch.cat([z_first, z_rest], dim=2)
+ else:
+ zq = F.interpolate(zq, size=f.shape[-3:])
+
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+class CogVideoXResnetBlock3D(nn.Module):
+ r"""
+ A 3D ResNet block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ temb_channels: int = 512,
+ groups: int = 32,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ conv_shortcut: bool = False,
+ spatial_norm_dim: Optional[int] = None,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(non_linearity)
+ self.use_conv_shortcut = conv_shortcut
+
+ if spatial_norm_dim is None:
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
+ else:
+ self.norm1 = CogVideoXSpatialNorm3D(
+ f_channels=in_channels,
+ zq_channels=spatial_norm_dim,
+ groups=groups,
+ )
+ self.norm2 = CogVideoXSpatialNorm3D(
+ f_channels=out_channels,
+ zq_channels=spatial_norm_dim,
+ groups=groups,
+ )
+
+ self.conv1 = CogVideoXCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ if temb_channels > 0:
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
+
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = CogVideoXCausalConv3d(
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CogVideoXCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+ else:
+ self.conv_shortcut = CogVideoXSafeConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = inputs
+
+ if zq is not None:
+ hidden_states = self.norm1(hidden_states, zq)
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if zq is not None:
+ hidden_states = self.norm2(hidden_states, zq)
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels:
+ inputs = self.conv_shortcut(inputs)
+
+ hidden_states = hidden_states + inputs
+ return hidden_states
+
+
+class CogVideoXDownBlock3D(nn.Module):
+ r"""
+ A downsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ add_downsample (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ add_downsample: bool = True,
+ downsample_padding: int = 0,
+ compress_time: bool = False,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channel,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ pad_mode=pad_mode,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.downsamplers = None
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ CogVideoXDownsample3D(
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
+ )
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class CogVideoXMidBlock3D(nn.Module):
+ r"""
+ A middle block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ spatial_norm_dim: Optional[int] = None,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ spatial_norm_dim=spatial_norm_dim,
+ non_linearity=resnet_act_fn,
+ pad_mode=pad_mode,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for resnet in self.resnets:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ return hidden_states
+
+
+class CogVideoXUpBlock3D(nn.Module):
+ r"""
+ An upsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, defaults to `16`):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ add_upsample (`bool`, defaults to `True`):
+ Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ spatial_norm_dim: int = 16,
+ add_upsample: bool = True,
+ upsample_padding: int = 1,
+ compress_time: bool = False,
+ pad_mode: str = "first",
+ ):
+ super().__init__()
+
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(
+ CogVideoXResnetBlock3D(
+ in_channels=in_channel,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=resnet_groups,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ spatial_norm_dim=spatial_norm_dim,
+ pad_mode=pad_mode,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.upsamplers = None
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ CogVideoXUpsample3D(
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
+ )
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ zq: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
+ for resnet in self.resnets:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def create_forward(*inputs):
+ return module(*inputs)
+
+ return create_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, zq
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CogVideoXEncoder3D(nn.Module):
+ r"""
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+ options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 16,
+ down_block_types: Tuple[str, ...] = (
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ dropout: float = 0.0,
+ pad_mode: str = "first",
+ temporal_compression_ratio: float = 4,
+ ):
+ super().__init__()
+
+ # log2 of temporal_compress_times
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+ self.down_blocks = nn.ModuleList([])
+
+ # down blocks
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ if down_block_type == "CogVideoXDownBlock3D":
+ down_block = CogVideoXDownBlock3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=layers_per_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ add_downsample=not is_final_block,
+ compress_time=compress_time,
+ )
+ else:
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
+
+ self.down_blocks.append(down_block)
+
+ # mid block
+ self.mid_block = CogVideoXMidBlock3D(
+ in_channels=block_out_channels[-1],
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=2,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ pad_mode=pad_mode,
+ )
+
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CogVideoXCausalConv3d(
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
+ hidden_states = self.conv_in(sample)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # 1. Down
+ for down_block in self.down_blocks:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block), hidden_states, temb, None
+ )
+
+ # 2. Mid
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), hidden_states, temb, None
+ )
+ else:
+ # 1. Down
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states, temb, None)
+
+ # 2. Mid
+ hidden_states = self.mid_block(hidden_states, temb, None)
+
+ # 3. Post-process
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class CogVideoXDecoder3D(nn.Module):
+ r"""
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = (
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ dropout: float = 0.0,
+ pad_mode: str = "first",
+ temporal_compression_ratio: float = 4,
+ ):
+ super().__init__()
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ self.conv_in = CogVideoXCausalConv3d(
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
+ )
+
+ # mid block
+ self.mid_block = CogVideoXMidBlock3D(
+ in_channels=reversed_block_out_channels[0],
+ temb_channels=0,
+ num_layers=2,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ spatial_norm_dim=in_channels,
+ pad_mode=pad_mode,
+ )
+
+ # up blocks
+ self.up_blocks = nn.ModuleList([])
+
+ output_channel = reversed_block_out_channels[0]
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ if up_block_type == "CogVideoXUpBlock3D":
+ up_block = CogVideoXUpBlock3D(
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=0,
+ dropout=dropout,
+ num_layers=layers_per_block + 1,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ spatial_norm_dim=in_channels,
+ add_upsample=not is_final_block,
+ compress_time=compress_time,
+ pad_mode=pad_mode,
+ )
+ prev_output_channel = output_channel
+ else:
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
+
+ self.up_blocks.append(up_block)
+
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CogVideoXCausalConv3d(
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
+ hidden_states = self.conv_in(sample)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # 1. Mid
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), hidden_states, temb, sample
+ )
+
+ # 2. Up
+ for up_block in self.up_blocks:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block), hidden_states, temb, sample
+ )
+ else:
+ # 1. Mid
+ hidden_states = self.mid_block(hidden_states, temb, sample)
+
+ # 2. Up
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states, temb, sample)
+
+ # 3. Post-process
+ hidden_states = self.norm_out(hidden_states, sample)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = (
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ up_block_types: Tuple[str] = (
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
+ latent_channels: int = 16,
+ layers_per_block: int = 3,
+ act_fn: str = "silu",
+ norm_eps: float = 1e-6,
+ norm_num_groups: int = 32,
+ temporal_compression_ratio: float = 4,
+ sample_height: int = 480,
+ sample_width: int = 720,
+ scaling_factor: float = 1.15258426,
+ shift_factor: Optional[float] = None,
+ latents_mean: Optional[Tuple[float]] = None,
+ latents_std: Optional[Tuple[float]] = None,
+ force_upcast: float = True,
+ use_quant_conv: bool = False,
+ use_post_quant_conv: bool = False,
+ ):
+ super().__init__()
+
+ self.encoder = CogVideoXEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_eps=norm_eps,
+ norm_num_groups=norm_num_groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.decoder = CogVideoXDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_eps=norm_eps,
+ norm_num_groups=norm_num_groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
+ # If you decode X latent frames together, the number of output frames is:
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
+ #
+ # Example with num_latent_frames_batch_size = 2:
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 6 * 8 = 48 frames
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+ # => 1 * 9 + 5 * 8 = 49 frames
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
+ # number of temporal frames.
+ self.num_latent_frames_batch_size = 2
+
+ # We make the minimum height and width of sample for tiling half that of the generally supported
+ self.tile_sample_min_height = sample_height // 2
+ self.tile_sample_min_width = sample_width // 2
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+ )
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
+ # and so the tiling implementation has only been tested on those specific resolutions.
+ self.tile_overlap_factor_height = 1 / 6
+ self.tile_overlap_factor_width = 1 / 5
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
+ module.gradient_checkpointing = value
+
+ def _clear_fake_context_parallel_cache(self):
+ for name, module in self.named_modules():
+ if isinstance(module, CogVideoXCausalConv3d):
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
+ module._clear_fake_context_parallel_cache()
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_overlap_factor_height: Optional[float] = None,
+ tile_overlap_factor_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_overlap_factor_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ tile_overlap_factor_width (`int`, *optional*):
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed leading to slow down of the decoding process.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_latent_min_height = int(
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+ )
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ batch_size, num_channels, num_frames, height, width = x.shape
+ if num_frames == 1:
+ h = self.encoder(x)
+ if self.quant_conv is not None:
+ h = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(h)
+ else:
+ frame_batch_size = 4
+ h = []
+ for i in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ z_intermediate = x[:, :, start_frame:end_frame]
+ z_intermediate = self.encoder(z_intermediate)
+ if self.quant_conv is not None:
+ z_intermediate = self.quant_conv(z_intermediate)
+ h.append(z_intermediate)
+ self._clear_fake_context_parallel_cache()
+ h = torch.cat(h, dim=2)
+ posterior = DiagonalGaussianDistribution(h)
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ if num_frames == 1:
+ dec = []
+ z_intermediate = z
+ if self.post_quant_conv is not None:
+ z_intermediate = self.post_quant_conv(z_intermediate)
+ z_intermediate = self.decoder(z_intermediate)
+ dec.append(z_intermediate)
+ else:
+ frame_batch_size = self.num_latent_frames_batch_size
+ dec = []
+ for i in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ z_intermediate = z[:, :, start_frame:end_frame]
+ if self.post_quant_conv is not None:
+ z_intermediate = self.post_quant_conv(z_intermediate)
+ z_intermediate = self.decoder(z_intermediate)
+ dec.append(z_intermediate)
+
+ self._clear_fake_context_parallel_cache()
+ dec = torch.cat(dec, dim=2)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ # Rough memory assessment:
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
+ # - Assume fp16 (2 bytes per value).
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
+ #
+ # Memory assessment when using tiling:
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
+ frame_batch_size = self.num_latent_frames_batch_size
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ time = []
+ for k in range(num_frames // frame_batch_size):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
+ tile = z[
+ :,
+ :,
+ start_frame:end_frame,
+ i : i + self.tile_latent_min_height,
+ j : j + self.tile_latent_min_width,
+ ]
+ if self.post_quant_conv is not None:
+ tile = self.post_quant_conv(tile)
+ tile = self.decoder(tile)
+ time.append(tile)
+ self._clear_fake_context_parallel_cache()
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[torch.Tensor, torch.Tensor]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ if not return_dict:
+ return (dec,)
+ return dec
diff --git a/models/crosstransformer3d.py b/models/crosstransformer3d.py
new file mode 100755
index 0000000000000000000000000000000000000000..9438f5fb92b0b54c10c9aa8d5a3334a615826519
--- /dev/null
+++ b/models/crosstransformer3d.py
@@ -0,0 +1,893 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import os
+import json
+import torch
+import glob
+import torch.nn.functional as F
+from torch import nn
+import math
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import is_torch_version, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+ return embeds
+
+class RefPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ def forward(self, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+ return image_embeds
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+def reshape_tensor(x, heads):
+ """
+ Reshapes the input tensor for multi-head attention.
+
+ Args:
+ x (torch.Tensor): The input tensor with shape (batch_size, length, width).
+ heads (int): The number of attention heads.
+
+ Returns:
+ torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
+ """
+ bs, length, width = x.shape
+ x = x.view(bs, length, heads, -1)
+ x = x.transpose(1, 2)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+class PerceiverCrossAttention(nn.Module):
+ """
+
+ Args:
+ dim (int): Dimension of the input latent and output. Default is 3072.
+ dim_head (int): Dimension of each attention head. Default is 128.
+ heads (int): Number of attention heads. Default is 16.
+ kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
+
+ Attributes:
+ scale (float): Scaling factor used in dot-product attention for numerical stability.
+ norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
+ norm2 (nn.LayerNorm): Layer normalization applied to the latent features.
+ to_q (nn.Linear): Linear layer for projecting the latent features into queries.
+ to_kv (nn.Linear): Linear layer for projecting the input features into keys and values.
+ to_out (nn.Linear): Linear layer for outputting the final result after attention.
+
+ """
+
+ def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ # Layer normalization to stabilize training
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ # Linear transformations to produce queries, keys, and values
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+
+ Args:
+ x (torch.Tensor): Input image features with shape (batch_size, n1, D), where:
+ - batch_size (b): Number of samples in the batch.
+ - n1: Sequence length (e.g., number of patches or tokens).
+ - D: Feature dimension.
+
+ latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
+ - n2: Number of latent elements.
+
+ Returns:
+ torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
+
+ """
+ # Apply layer normalization to the input image and latent features
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, seq_len, _ = latents.shape
+
+ # Compute queries, keys, and values
+ q = self.to_q(latents)
+ k, v = self.to_kv(x).chunk(2, dim=-1)
+
+ # Reshape tensors to split into attention heads
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # Compute attention weights
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable scaling than post-division
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # Compute the output via weighted combination of values
+ out = weight @ v
+
+ # Reshape and permute to prepare for final linear transformation
+ out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
+
+ return self.to_out(out)
+
+
+class CrossTransformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ add_noise_in_inpaint_model: bool = False,
+ is_train_cross: bool = False,
+ cross_attn_in_channels: int = 16,
+ cross_attn_interval: int = 2,
+ cross_attn_dim_head: int = 128,
+ cross_attn_num_heads: int = 16,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ post_patch_height = sample_height // patch_size
+ post_patch_width = sample_width // patch_size
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+ self.post_patch_height = post_patch_height
+ self.post_patch_width = post_patch_width
+ self.post_time_compression_frames = post_time_compression_frames
+ self.patch_size = patch_size
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. 3D positional embeddings
+ spatial_pos_embedding = get_3d_sincos_pos_embed(
+ inner_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ spatial_interpolation_scale,
+ temporal_interpolation_scale,
+ )
+ spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+ pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
+ pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+ # 3. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 4. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 5. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ self.is_train_cross = is_train_cross
+ if is_train_cross:
+ # cross configs
+ self.inner_dim = inner_dim
+ self.cross_attn_interval = cross_attn_interval
+ self.num_cross_attn = num_layers // cross_attn_interval
+ self.cross_attn_dim_head = cross_attn_dim_head
+ self.cross_attn_num_heads = cross_attn_num_heads
+ self.cross_attn_kv_dim = None
+ self.ref_patch_embed = RefPatchEmbed(patch_size, cross_attn_in_channels, inner_dim, bias=True)
+ self._init_cross_inputs()
+
+ def _init_cross_inputs(self):
+ device = self.device
+ weight_dtype = self.dtype
+ self.perceiver_cross_attention = nn.ModuleList(
+ [
+ PerceiverCrossAttention(
+ dim=self.inner_dim,
+ dim_head=self.cross_attn_dim_head,
+ heads=self.cross_attn_num_heads,
+ kv_dim=self.cross_attn_kv_dim,
+ ).to(device, dtype=weight_dtype)
+ for _ in range(self.num_cross_attn)
+ ]
+ )
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, #noise
+ encoder_hidden_states: torch.Tensor, #text
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ inpaint_latents: Optional[torch.Tensor] = None, #condition
+ cross_latents: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ return_dict: bool = True,
+ ):
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ # [2, 13, 16, 48, 84] cat [2, 13, 17, 48, 84] = [2, 13, 33, 48, 84]
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ if self.is_train_cross:
+ cross_hidden_states = self.ref_patch_embed(cross_latents)
+
+ # 3. Position embedding
+ text_seq_length = encoder_hidden_states.shape[1]
+ if not self.config.use_rotary_positional_embeddings:
+ seq_length = height * width * num_frames // (self.config.patch_size**2)
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
+ pos_embeds = self.pos_embedding
+ emb_size = hidden_states.size()[-1]
+ pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
+ hidden_states = hidden_states + pos_embeds
+ hidden_states = self.embedding_dropout(hidden_states)
+ # seperate
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Transformer blocks
+
+ ca_idx = 0
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ if self.is_train_cross:
+ if i % self.cross_attn_interval == 0:
+ hidden_states = hidden_states + self.perceiver_cross_attention[ca_idx](
+ cross_hidden_states, hidden_states
+ ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
+ ca_idx += 1
+
+ # if not self.config.use_rotary_positional_embeddings:
+ # # CogVideoX-2B
+ # hidden_states = self.norm_final(hidden_states)
+ # else:
+ # use CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 5. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 6. Unpatchify
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file):
+ state_dict = torch.load(model_file, map_location="cpu")
+ elif os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ from safetensors.torch import load_file, safe_open
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+ state_dict = {}
+ for model_file_safetensors in model_files_safetensors:
+ _state_dict = load_file(model_file_safetensors)
+ for key in _state_dict:
+ state_dict[key] = _state_dict[key]
+
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
+ if len(new_shape) == 5:
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
+ else:
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+ else:
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m)
+
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+ return model
+
+ @classmethod
+ def from_pretrained_cus(cls, pretrained_model_path, subfolder=None, config_path=None, transformer_additional_kwargs={}):
+ if subfolder:
+ config_path = config_path or pretrained_model_path
+ config_file = os.path.join(config_path, subfolder, 'config.json')
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ else:
+ config_file = os.path.join(config_path or pretrained_model_path, 'config.json')
+
+ print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ # Check if config file exists
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"Configuration file '{config_file}' does not exist")
+
+ # Load the configuration
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file):
+ state_dict = torch.load(model_file, map_location="cpu")
+ elif os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file
+ state_dict = load_file(model_file_safetensors)
+ else:
+ from safetensors.torch import load_file
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+ state_dict = {}
+ for model_file_safetensors in model_files_safetensors:
+ _state_dict = load_file(model_file_safetensors)
+ for key in _state_dict:
+ state_dict[key] = _state_dict[key]
+
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
+ if len(new_shape) == 5:
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
+ else:
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+ else:
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m)
+
+ params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+
+ return model
+
+if __name__ == '__main__':
+ device = "cuda:0"
+ weight_dtype = torch.bfloat16
+ model_path = "/group/40075/wangboyu/CogVideoX-Fun/CogVideoX-Fun-V1.1-5b-InP"
+
+ transformer_additional_kwargs={
+ 'is_train_cross': True,
+ 'cross_attn_in_channels': 16,
+ 'cross_attn_interval': 2,
+ 'cross_attn_dim_head' : 128,
+ 'cross_attn_num_heads':16,
+ }
+
+ transformer = CrossTransformer3DModel.from_pretrained_2d(
+ model_path,
+ subfolder="transformer",
+ transformer_additional_kwargs=transformer_additional_kwargs,
+ )
+
+ transformer.to(device, dtype=weight_dtype)
+ for param in transformer.parameters():
+ param.requires_grad = False
+ transformer.eval()
+
+ b = 1
+ dim = 16
+ noisy_latents = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype)
+ inpaint_latents = torch.ones(b, 13, dim+1, 60, 90).to(device, dtype=weight_dtype)
+ # cross_latents = torch.ones(b, 13, dim, 60, 90).to(device, dtype=weight_dtype)
+ cross_latents = torch.ones(b, 1, dim, 60, 90).to(device, dtype=weight_dtype)
+ prompt_embeds = torch.ones(b, 226, 4096).to(device, dtype=weight_dtype)
+ image_rotary_emb = (torch.ones(17550, 64).to(device, dtype=weight_dtype), torch.ones(17550, 64).to(device, dtype=weight_dtype))
+ timesteps = torch.tensor([311]).to(device, dtype=weight_dtype)
+ assert len(timesteps) == b
+
+ model_output = transformer(
+ hidden_states=noisy_latents,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ inpaint_latents=inpaint_latents,
+ cross_latents = cross_latents,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
+
+ print(model_output)
diff --git a/models/pipeline_trajectorycrafter.py b/models/pipeline_trajectorycrafter.py
new file mode 100755
index 0000000000000000000000000000000000000000..00b02424f6315057fa6cb13e89354df9fefecb68
--- /dev/null
+++ b/models/pipeline_trajectorycrafter.py
@@ -0,0 +1,1005 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.image_processor import VaeImageProcessor
+from einops import rearrange
+from models.crosstransformer3d import CrossTransformer3DModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+ latent_size = latent.size()
+ batch_size, channels, num_frames, height, width = mask.shape
+
+ if process_first_frame_only:
+ target_size = list(latent_size[2:])
+ target_size[0] = 1
+ first_frame_resized = F.interpolate(
+ mask[:, :, 0:1, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+
+ target_size = list(latent_size[2:])
+ target_size[0] = target_size[0] - 1
+ if target_size[0] != 0:
+ remaining_frames_resized = F.interpolate(
+ mask[:, :, 1:, :, :],
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+ else:
+ resized_mask = first_frame_resized
+ else:
+ target_size = list(latent_size[2:])
+ resized_mask = F.interpolate(
+ mask,
+ size=target_size,
+ mode='trilinear',
+ align_corners=False
+ )
+ return resized_mask
+
+
+def add_noise_to_reference_video(image, ratio=None):
+ if ratio is None:
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+ sigma = torch.exp(sigma).to(image.dtype)
+ else:
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+ image = image + image_noise
+ return image
+
+
+@dataclass
+class CogVideoX_Fun_PipelineOutput(BaseOutput):
+ r"""
+ Output class for CogVideo pipelines.
+
+ Args:
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ videos: torch.Tensor
+
+
+class TrajCrafter_Pipeline(DiffusionPipeline):
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CrossTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ video=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_video_latents=False,
+ ):
+ shape = (
+ batch_size,
+ (video_length - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if return_video_latents or (latents is None and not is_strength_max):
+ video = video.to(device=device, dtype=self.vae.dtype)
+
+ bs = 1
+ new_video = []
+ for i in range(0, video.shape[0], bs):
+ video_bs = video[i : i + bs]
+ video_bs = self.vae.encode(video_bs)[0]
+ video_bs = video_bs.sample()
+ new_video.append(video_bs)
+ video = torch.cat(new_video, dim = 0)
+ video = video * self.vae.config.scaling_factor
+
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
+ video_latents = video_latents.to(device=device, dtype=dtype)
+ video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
+
+ if latents is None: #this branch
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_video_latents:
+ outputs += (video_latents,)
+
+ return outputs
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+
+ if mask is not None:
+ mask = mask.to(device=device, dtype=self.vae.dtype)
+ bs = 1
+ new_mask = []
+ for i in range(0, mask.shape[0], bs):
+ mask_bs = mask[i : i + bs]
+ mask_bs = self.vae.encode(mask_bs)[0]
+ mask_bs = mask_bs.mode()
+ new_mask.append(mask_bs)
+ mask = torch.cat(new_mask, dim = 0)
+ mask = mask * self.vae.config.scaling_factor
+
+ if masked_image is not None:
+ if self.transformer.config.add_noise_in_inpaint_model:
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
+ bs = 1
+ new_mask_pixel_values = []
+ for i in range(0, masked_image.shape[0], bs):
+ mask_pixel_values_bs = masked_image[i : i + bs]
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
+ new_mask_pixel_values.append(mask_pixel_values_bs)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+ else:
+ masked_image_latents = None
+
+ return mask, masked_image_latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ frames = self.vae.decode(latents).sample
+ frames = (frames / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ frames = frames.cpu().float().numpy()
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ use_real=True,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ video: Union[torch.FloatTensor] = None,
+ mask_video: Union[torch.FloatTensor] = None,
+ reference: Union[torch.FloatTensor] = None,
+ masked_video_latents: Union[torch.FloatTensor] = None,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "numpy",
+ return_dict: bool = False,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ strength: float = 1,
+ noise_aug_strength: float = 0.0563,
+ comfyui_progressbar: bool = False,
+ ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_frames (`int`, defaults to `48`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ self._num_timesteps = len(timesteps)
+ if comfyui_progressbar:
+ from comfy.utils import ProgressBar
+ pbar = ProgressBar(num_inference_steps + 2)
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Prepare latents.
+ if video is not None:
+ video_length = video.shape[2]
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ init_video = init_video.to(dtype=torch.float32)
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ init_video = None
+
+ ref_length = reference.shape[2]
+ ref_video = self.image_processor.preprocess(rearrange(reference, "b c f h w -> (b f) c h w"), height=height, width=width)
+ ref_video = rearrange(ref_video, "(b f) c h w -> b c f h w", f=ref_length)
+ bs = 1
+ ref_video = ref_video.to(device=device, dtype=self.vae.dtype)
+ new_ref_video = []
+ for i in range(0, ref_video.shape[0], bs):
+ video_bs = ref_video[i : i + bs]
+ video_bs = self.vae.encode(video_bs)[0]
+ video_bs = video_bs.sample()
+ new_ref_video.append(video_bs)
+ new_ref_video = torch.cat(new_ref_video, dim = 0)
+ new_ref_video = new_ref_video * self.vae.config.scaling_factor
+ ref_latents = new_ref_video.repeat(batch_size // new_ref_video.shape[0], 1, 1, 1, 1)
+ ref_latents = ref_latents.to(device=self.device, dtype=self.dtype)
+ ref_latents = rearrange(ref_latents, "b c f h w -> b f c h w")
+ ref_input = torch.cat([ref_latents] * 2) if do_classifier_free_guidance else ref_latents
+
+
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_transformer = self.transformer.config.in_channels
+ return_image_latents = num_channels_transformer == num_channels_latents
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ video=init_video,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_video_latents=return_image_latents,
+ )
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+ if comfyui_progressbar:
+ pbar.update(1)
+ # [1, 3, 49, 384, 672] to [1, 13, 16, 48, 84]
+ if mask_video is not None:
+ if (mask_video == 255).all():
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
+ else:
+ # Prepare mask latent variables
+ video_length = video.shape[2]
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ mask_condition = mask_condition.to(dtype=torch.float32)
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
+ #[0,1]
+ if num_channels_transformer != num_channels_latents:
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
+ if masked_video_latents is None:
+ #在 mask_condition_tile 小于 0.5(即0,首帧) 的位置,masked_video 保留 init_video 的值;在 mask_condition_tile 大于 0.5(即1) 的位置,masked_video 的值被设置为 -1
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
+ else:
+ masked_video = masked_video_latents
+
+ _, masked_video_latents = self.prepare_mask_latents(
+ None,
+ masked_video,
+ batch_size,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ noise_aug_strength=noise_aug_strength,
+ )
+ # mask at latent size, 1 is valid,第一帧变成1,后面变成0
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
+ #缩放1的数值
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
+
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+
+ # input is with cfg guidance
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+
+ mask = rearrange(mask, "b c f h w -> b f c h w")
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
+ # channel cat
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
+ else:
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+ mask = rearrange(mask, "b c f h w -> b f c h w")
+
+ inpaint_latents = None
+ else:
+ if num_channels_transformer != num_channels_latents:
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
+ else:
+ mask = torch.zeros_like(init_video[:, :1])
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+ mask = rearrange(mask, "b c f h w -> b f c h w")
+
+ inpaint_latents = None
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) # h w t
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ # 输入普通latents(input image repeat成视频和带mask的latents)
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ inpaint_latents=inpaint_latents,
+ cross_latents = ref_input,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if comfyui_progressbar:
+ pbar.update(1)
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return CogVideoX_Fun_PipelineOutput(videos=video)
diff --git a/models/utils.py b/models/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..bc1e335cc4e3e3e8cf0755f03b71546661fec4fe
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,430 @@
+import numpy as np
+import cv2
+import PIL
+from PIL import Image
+import os
+from datetime import datetime
+import pdb
+import torch.nn.functional as F
+import numpy as np
+import os
+import cv2
+import copy
+from scipy.interpolate import UnivariateSpline, interp1d
+import numpy as np
+import PIL.Image
+import torch
+import torchvision
+from tqdm import tqdm
+from pathlib import Path
+from typing import Tuple, Optional
+import cv2
+import PIL
+import numpy
+import skimage.io
+import torch
+import torch.nn.functional as F
+from decord import VideoReader, cpu
+
+def read_video_frames(video_path, process_length, stride, max_res, dataset="open"):
+ if dataset == "open":
+ print("==> processing video: ", video_path)
+ vid = VideoReader(video_path, ctx=cpu(0))
+ print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
+ # original_height, original_width = vid.get_batch([0]).shape[1:3]
+ # height = round(original_height / 64) * 64
+ # width = round(original_width / 64) * 64
+ # if max(height, width) > max_res:
+ # scale = max_res / max(original_height, original_width)
+ # height = round(original_height * scale / 64) * 64
+ # width = round(original_width * scale / 64) * 64
+
+ #FIXME: hard coded
+ width = 1024
+ height = 576
+
+ vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
+
+ frames_idx = list(range(0, len(vid), stride))
+ print(
+ f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
+ )
+ if process_length != -1 and process_length < len(frames_idx):
+ frames_idx = frames_idx[:process_length]
+ print(
+ f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
+ )
+ frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
+
+ return frames
+
+
+def save_video(data,images_path,folder=None,fps=8):
+ if isinstance(data, np.ndarray):
+ tensor_data = (torch.from_numpy(data) * 255).to(torch.uint8)
+ elif isinstance(data, torch.Tensor):
+ tensor_data = (data.detach().cpu() * 255).to(torch.uint8)
+ elif isinstance(data, list):
+ folder = [folder]*len(data)
+ images = [np.array(Image.open(os.path.join(folder_name,path))) for folder_name,path in zip(folder,data)]
+ stacked_images = np.stack(images, axis=0)
+ tensor_data = torch.from_numpy(stacked_images).to(torch.uint8)
+ torchvision.io.write_video(images_path, tensor_data, fps=fps, video_codec='h264', options={'crf': '10'})
+
+def sphere2pose(c2ws_input, theta, phi, r,device,x=None,y=None):
+ c2ws = copy.deepcopy(c2ws_input)
+ # c2ws[:,2, 3] = c2ws[:,2, 3] - radius
+
+ #先沿着世界坐标系z轴方向平移再旋转
+ c2ws[:,2,3] -= r
+ if x is not None:
+ c2ws[:,1,3] += y
+ if y is not None:
+ c2ws[:,0,3] -= x
+
+ theta = torch.deg2rad(torch.tensor(theta)).to(device)
+ sin_value_x = torch.sin(theta)
+ cos_value_x = torch.cos(theta)
+ rot_mat_x = torch.tensor([[1, 0, 0, 0],
+ [0, cos_value_x, -sin_value_x, 0],
+ [0, sin_value_x, cos_value_x, 0],
+ [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device)
+
+ phi = torch.deg2rad(torch.tensor(phi)).to(device)
+ sin_value_y = torch.sin(phi)
+ cos_value_y = torch.cos(phi)
+ rot_mat_y = torch.tensor([[cos_value_y, 0, sin_value_y, 0],
+ [0, 1, 0, 0],
+ [-sin_value_y, 0, cos_value_y, 0],
+ [0, 0, 0, 1]]).unsqueeze(0).repeat(c2ws.shape[0],1,1).to(device)
+
+ c2ws = torch.matmul(rot_mat_x,c2ws)
+ c2ws = torch.matmul(rot_mat_y,c2ws)
+ # c2ws[:,2, 3] = c2ws[:,2, 3] + radius
+ return c2ws
+
+def generate_traj_specified(c2ws_anchor,theta, phi,d_r,d_x,d_y,frame,device):
+ # Initialize a camera.
+ thetas = np.linspace(0,theta,frame)
+ phis = np.linspace(0,phi,frame)
+ rs = np.linspace(0,d_r,frame)
+ xs = np.linspace(0,d_x,frame)
+ ys = np.linspace(0,d_y,frame)
+ c2ws_list = []
+ for th, ph, r, x, y in zip(thetas,phis,rs, xs, ys):
+ c2w_new = sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(r), device, np.float32(x),np.float32(y))
+ c2ws_list.append(c2w_new)
+ c2ws = torch.cat(c2ws_list,dim=0)
+ return c2ws
+
+def txt_interpolation(input_list,n,mode = 'smooth'):
+ x = np.linspace(0, 1, len(input_list))
+ if mode == 'smooth':
+ f = UnivariateSpline(x, input_list, k=3)
+ elif mode == 'linear':
+ f = interp1d(x, input_list)
+ else:
+ raise KeyError(f"Invalid txt interpolation mode: {mode}")
+ xnew = np.linspace(0, 1, n)
+ ynew = f(xnew)
+ return ynew
+
+def generate_traj_txt(c2ws_anchor,phi, theta, r,frame,device):
+ # Initialize a camera.
+ """
+ The camera coordinate sysmte in COLMAP is right-down-forward
+ Pytorch3D is left-up-forward
+ """
+
+ if len(phi)>3:
+ phis = txt_interpolation(phi,frame,mode='smooth')
+ phis[0] = phi[0]
+ phis[-1] = phi[-1]
+ else:
+ phis = txt_interpolation(phi,frame,mode='linear')
+
+ if len(theta)>3:
+ thetas = txt_interpolation(theta,frame,mode='smooth')
+ thetas[0] = theta[0]
+ thetas[-1] = theta[-1]
+ else:
+ thetas = txt_interpolation(theta,frame,mode='linear')
+
+ if len(r) >3:
+ rs = txt_interpolation(r,frame,mode='smooth')
+ rs[0] = r[0]
+ rs[-1] = r[-1]
+ else:
+ rs = txt_interpolation(r,frame,mode='linear')
+ # rs = rs*c2ws_anchor[0,2,3].cpu().numpy()
+
+ c2ws_list = []
+ for th, ph, r in zip(thetas,phis,rs):
+ c2w_new = sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(r), device)
+ c2ws_list.append(c2w_new)
+ c2ws = torch.cat(c2ws_list,dim=0)
+ return c2ws
+
+class Warper:
+ def __init__(self, resolution: tuple = None, device: str = 'gpu0'):
+ self.resolution = resolution
+ self.device = self.get_device(device)
+ self.dtype = torch.float32
+ return
+
+ def forward_warp(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
+ transformation1: torch.Tensor, transformation2: torch.Tensor, intrinsic1: torch.Tensor,
+ intrinsic2: Optional[torch.Tensor], mask=False, twice=False) -> \
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using
+ bilinear splatting.
+ All arrays should be torch tensors with batch dimension and channel first
+ :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling
+ bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting()
+ method accordingly.
+ :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional
+ :param depth1: (b, 1, h, w)
+ :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
+ :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
+ :param intrinsic1: (b, 3, 3) camera intrinsic matrix
+ :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
+ """
+ if self.resolution is not None:
+ assert frame1.shape[2:4] == self.resolution
+ b, c, h, w = frame1.shape
+ if mask1 is None:
+ mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
+ if intrinsic2 is None:
+ intrinsic2 = intrinsic1.clone()
+
+ assert frame1.shape == (b, 3, h, w)
+ assert mask1.shape == (b, 1, h, w)
+ assert depth1.shape == (b, 1, h, w)
+ assert transformation1.shape == (b, 4, 4)
+ assert transformation2.shape == (b, 4, 4)
+ assert intrinsic1.shape == (b, 3, 3)
+ assert intrinsic2.shape == (b, 3, 3)
+
+ frame1 = frame1.to(self.device).to(self.dtype)
+ mask1 = mask1.to(self.device).to(self.dtype)
+ depth1 = depth1.to(self.device).to(self.dtype)
+ transformation1 = transformation1.to(self.device).to(self.dtype)
+ transformation2 = transformation2.to(self.device).to(self.dtype)
+ intrinsic1 = intrinsic1.to(self.device).to(self.dtype)
+ intrinsic2 = intrinsic2.to(self.device).to(self.dtype)
+
+ trans_points1 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1,
+ intrinsic2)
+ trans_coordinates = trans_points1[:,:, :, :2, 0] / trans_points1[:,:, :, 2:3, 0]
+ trans_depth1 = trans_points1[:,:, :, 2, 0]
+ grid = self.create_grid(b, h, w).to(trans_coordinates)
+ flow12 = trans_coordinates.permute(0,3,1,2) - grid
+ if not twice:
+ warped_frame2, mask2 = self.bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=True)
+ if mask:
+ warped_frame2, mask2 = self.clean_points(warped_frame2, mask2)
+ return warped_frame2, mask2, None, flow12
+
+ else:
+ warped_frame2, mask2 = self.bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=True)
+ # warped_frame2, mask2 = self.clean_points(warped_frame2, mask2)
+ warped_flow, _ = self.bilinear_splatting(flow12, mask1, trans_depth1, flow12, None, is_image=False)
+ twice_warped_frame1 ,_ = self.bilinear_splatting(warped_frame2, mask2, depth1.squeeze(1), -warped_flow, None, is_image=True)
+ return twice_warped_frame1, warped_frame2, None, None
+
+ def compute_transformed_points(self, depth1: torch.Tensor, transformation1: torch.Tensor, transformation2: torch.Tensor,
+ intrinsic1: torch.Tensor, intrinsic2: Optional[torch.Tensor]):
+ """
+ Computes transformed position for each pixel location
+ """
+ if self.resolution is not None:
+ assert depth1.shape[2:4] == self.resolution
+ b, _, h, w = depth1.shape
+ if intrinsic2 is None:
+ intrinsic2 = intrinsic1.clone()
+ transformation = torch.bmm(transformation2, torch.linalg.inv(transformation1)) # (b, 4, 4)
+
+ x1d = torch.arange(0, w)[None]
+ y1d = torch.arange(0, h)[:, None]
+ x2d = x1d.repeat([h, 1]).to(depth1) # (h, w)
+ y2d = y1d.repeat([1, w]).to(depth1) # (h, w)
+ ones_2d = torch.ones(size=(h, w)).to(depth1) # (h, w)
+ ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1)
+ pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1)
+
+ intrinsic1_inv = torch.linalg.inv(intrinsic1) # (b, 3, 3)
+ intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3)
+ intrinsic2_4d = intrinsic2[:, None, None] # (b, 1, 1, 3, 3)
+ depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1)
+ trans_4d = transformation[:, None, None] # (b, 1, 1, 4, 4)
+
+ unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo) # (b, h, w, 3, 1)
+ world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1)
+ world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1)
+ trans_world_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1)
+ trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1)
+ trans_norm_points = torch.matmul(intrinsic2_4d, trans_world) # (b, h, w, 3, 1)
+ return trans_norm_points
+
+ def bilinear_splatting(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
+ flow12: torch.Tensor, flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Bilinear splatting
+ :param frame1: (b,c,h,w)
+ :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional
+ :param depth1: (b,1,h,w)
+ :param flow12: (b,2,h,w)
+ :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional
+ :param is_image: if true, output will be clipped to (-1,1) range
+ :return: warped_frame2: (b,c,h,w)
+ mask2: (b,1,h,w): 1 for known and 0 for unknown
+ """
+ if self.resolution is not None:
+ assert frame1.shape[2:4] == self.resolution
+ b, c, h, w = frame1.shape
+ if mask1 is None:
+ mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
+ if flow12_mask is None:
+ flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
+ grid = self.create_grid(b, h, w).to(frame1)
+ trans_pos = flow12 + grid
+
+ trans_pos_offset = trans_pos + 1
+ trans_pos_floor = torch.floor(trans_pos_offset).long()
+ trans_pos_ceil = torch.ceil(trans_pos_offset).long()
+ trans_pos_offset = torch.stack([
+ torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
+ torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
+ trans_pos_floor = torch.stack([
+ torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
+ torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
+ trans_pos_ceil = torch.stack([
+ torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
+ torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)
+
+ prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
+ prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
+ prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
+ prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
+
+ sat_depth1 = torch.clamp(depth1, min=0, max=1000)
+ log_depth1 = torch.log(1 + sat_depth1)
+ depth_weights = torch.exp(log_depth1 / log_depth1.max() * 50)
+
+ weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights.unsqueeze(1), [0, 1, 2, 3], [0, 3, 1, 2])
+ weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights.unsqueeze(1), [0, 1, 2, 3], [0, 3, 1, 2])
+ weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights.unsqueeze(1), [0, 1, 2, 3], [0, 3, 1, 2])
+ weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights.unsqueeze(1), [0, 1, 2, 3], [0, 3, 1, 2])
+
+ warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=torch.float32).to(frame1)
+ warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=torch.float32).to(frame1)
+
+ frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
+ batch_indices = torch.arange(b)[:, None, None].to(frame1.device)
+ warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),frame1_cl * weight_nw, accumulate=True)
+ warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),frame1_cl * weight_sw, accumulate=True)
+ warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),frame1_cl * weight_ne, accumulate=True)
+ warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),frame1_cl * weight_se, accumulate=True)
+
+ warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),weight_nw, accumulate=True)
+ warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),weight_sw, accumulate=True)
+ warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),weight_ne, accumulate=True)
+ warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),weight_se, accumulate=True)
+
+ warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
+ warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
+ cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
+ cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]
+
+ mask = cropped_weights > 0
+ zero_value = -1 if is_image else 0
+ zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
+ warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
+ mask2 = mask.to(frame1)
+
+ if is_image:
+ assert warped_frame2.min() >= -1.1 # Allow for rounding errors
+ assert warped_frame2.max() <= 1.1
+ warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)
+ return warped_frame2, mask2
+
+ def clean_points(self, warped_frame2, mask2):
+ warped_frame2 = (warped_frame2 + 1.)/2.
+ mask = 1-mask2
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = mask.squeeze(0).repeat(3,1,1).permute(1,2,0)*255.
+ mask = mask.cpu().numpy()
+ kernel = numpy.ones((5,5), numpy.uint8)
+ mask_erosion = cv2.dilate(numpy.array(mask), kernel, iterations = 1)
+ mask_erosion = PIL.Image.fromarray(numpy.uint8(mask_erosion))
+ mask_erosion_ = numpy.array(mask_erosion)/255.
+ mask_erosion_[mask_erosion_ < 0.5] = 0
+ mask_erosion_[mask_erosion_ >= 0.5] = 1
+ mask_new = torch.from_numpy(mask_erosion_).permute(2,0,1).unsqueeze(0).to(self.device)
+ warped_frame2 = warped_frame2*(1-mask_new)
+ return warped_frame2*2.-1., 1-mask_new[:,0:1,:,:]
+
+ @staticmethod
+ def create_grid(b, h, w):
+ x_1d = torch.arange(0, w)[None]
+ y_1d = torch.arange(0, h)[:, None]
+ x_2d = x_1d.repeat([h, 1])
+ y_2d = y_1d.repeat([1, w])
+ grid = torch.stack([x_2d, y_2d], dim=0)
+ batch_grid = grid[None].repeat([b, 1, 1, 1])
+ return batch_grid
+
+ @staticmethod
+ def read_image(path: Path) -> torch.Tensor:
+ image = skimage.io.imread(path.as_posix())
+ return image
+
+ @staticmethod
+ def read_depth(path: Path) -> torch.Tensor:
+ if path.suffix == '.png':
+ depth = skimage.io.imread(path.as_posix())
+ elif path.suffix == '.npy':
+ depth = numpy.load(path.as_posix())
+ elif path.suffix == '.npz':
+ with numpy.load(path.as_posix()) as depth_data:
+ depth = depth_data['depth']
+ else:
+ raise RuntimeError(f'Unknown depth format: {path.suffix}')
+ return depth
+
+ @staticmethod
+ def camera_intrinsic_transform(capture_width=1920, capture_height=1080, patch_start_point: tuple = (0, 0)):
+ start_y, start_x = patch_start_point
+ camera_intrinsics = numpy.eye(4)
+ camera_intrinsics[0, 0] = 2100
+ camera_intrinsics[0, 2] = capture_width / 2.0 - start_x
+ camera_intrinsics[1, 1] = 2100
+ camera_intrinsics[1, 2] = capture_height / 2.0 - start_y
+ return camera_intrinsics
+
+ @staticmethod
+ def get_device(device: str):
+ """
+ Returns torch device object
+ :param device: cpu/gpu0/gpu1
+ :return:
+ """
+ if device == 'cpu':
+ device = torch.device('cpu')
+ elif device.startswith('gpu') and torch.cuda.is_available():
+ gpu_num = int(device[3:])
+ device = torch.device(f'cuda:{gpu_num}')
+ else:
+ device = torch.device('cpu')
+ return device
+
+
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000000000000000000000000000000000000..254059538a44f7499c776dc320422a683348347b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+Pillow
+einops
+safetensors
+timm
+tomesd
+torchdiffeq
+torchsde
+xformers
+decord
+datasets
+numpy
+scikit-image
+opencv-python
+omegaconf
+SentencePiece
+albumentations
+imageio[ffmpeg]
+imageio[pyav]
+tensorboard
+beautifulsoup4
+ftfy
+func_timeout
+deepspeed
+accelerate>=0.25.0
+diffusers>=0.30.1
+transformers==4.47
+av==12.0.0
+gradio
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100755
index 0000000000000000000000000000000000000000..d7c6e04d30903b49122ef74dc2e7445d0cdfca6e
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+python inference.py \
+ --video_path './test/videos/p7.mp4' \
+ --stride 2 \
+ --out_dir experiments \
+ --radius_scale 1 \
+ --camera 'target' \
+ --mode 'gradual' \
+ --mask \
+ --target_pose 0 -30 0.3 0 0 \
+ --traj_txt 'test/trajs/loop2.txt' \
diff --git a/test/trajs/loop1.txt b/test/trajs/loop1.txt
new file mode 100755
index 0000000000000000000000000000000000000000..b9bb29374e74be0149175c464f69bbe5a066f1ea
--- /dev/null
+++ b/test/trajs/loop1.txt
@@ -0,0 +1,3 @@
+0 2 10 15 12 6 0 -2 -5 -12 -8 -3 0
+0 -3 -10 -20 -30 -25 -17 -10 0
+0 0.02 0.09 0.16 0.25 0.2 0.09 0
\ No newline at end of file
diff --git a/test/trajs/loop2.txt b/test/trajs/loop2.txt
new file mode 100755
index 0000000000000000000000000000000000000000..57a50b4db393d851aada3a1bfee872a61f1b0b8f
--- /dev/null
+++ b/test/trajs/loop2.txt
@@ -0,0 +1,3 @@
+0 2 10 15 12 6 0 -2 -5 -12 -8 -3 0
+0 3 10 20 30 25 17 10 0
+0 0.02 0.09 0.16 0.25 0.28 0.19 0.09 0
\ No newline at end of file
diff --git a/test/videos/0-NNvgaTcVzAG0-r.mp4 b/test/videos/0-NNvgaTcVzAG0-r.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5c0f3681886326cd8cf86617f1d7ec6d8f5123ac
--- /dev/null
+++ b/test/videos/0-NNvgaTcVzAG0-r.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26d5168e468959d0e4d0f30d47f322209de20c988c1d5315426bc4a9c9ee623a
+size 654822
diff --git a/test/videos/UST-fn-RvhJwMR5S.mp4 b/test/videos/UST-fn-RvhJwMR5S.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c424f7ddc97792a1ba08f1bcdeff3031352ae059
--- /dev/null
+++ b/test/videos/UST-fn-RvhJwMR5S.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e922227b156bbc6a6962a91afbf8b1bc5681495febb53a7203cfc6475bcabcc0
+size 1198452
diff --git a/test/videos/p7.mp4 b/test/videos/p7.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3fbc188564781c3b1adcd857d3a001e588c517e8
--- /dev/null
+++ b/test/videos/p7.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2db72296fb9c7c3b6f7dd832db7722069855f0d7d921e397d51b7e26631b8af
+size 1326612
diff --git a/test/videos/part-2-3.mp4 b/test/videos/part-2-3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a0c62267533797ed90e0c409487403abc67d913c
--- /dev/null
+++ b/test/videos/part-2-3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:398e66359ddde94853770640b7caaa057e880d007c86b8e52164d6e99e417ec6
+size 665861
diff --git a/test/videos/tUfDESZsQFhdDW9S.mp4 b/test/videos/tUfDESZsQFhdDW9S.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3193425f164a04723e2d339ef3d59bd244ad55b3
--- /dev/null
+++ b/test/videos/tUfDESZsQFhdDW9S.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:894b1c994b6b33ee3ab0a53b5a7d6244ef4b31f41ade0ed2266288841c3ac61f
+size 3166443