diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc4d1194d7dff69574e64c37896ed6629cb5a770
--- /dev/null
+++ b/app.py
@@ -0,0 +1,747 @@
+import os
+import time
+import random
+
+import gradio as gr
+import cv2
+import numpy as np
+from PIL import Image
+
+os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True)
+
+from huggingface_hub import snapshot_download
+
+def download_sam2():
+ snapshot_download(
+ repo_id="facebook/sam2-hiera-large",
+ local_dir="./sam2/SAM2-Video-Predictor/checkpoints/",
+ )
+ print("Download sam2 completed")
+
+def download_refacade():
+ snapshot_download(
+ repo_id="fishze/Refacade",
+ local_dir="./models/",
+ )
+ print("Download refacade completed")
+
+
+download_sam2()
+download_refacade()
+
+import torch
+import torch.nn.functional as F
+from decord import VideoReader, cpu
+from moviepy.editor import ImageSequenceClip
+from sam2.build_sam import build_sam2, build_sam2_video_predictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+import spaces
+from pipeline import RefacadePipeline
+from vace.models.wan.modules.model_mm import VaceMMModel
+from vace.models.wan.modules.model_tr import VaceWanModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from wan.text2video import FlowUniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image, load_video
+from vae import WanVAE
+
+COLOR_PALETTE = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (255, 255, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (255, 128, 0),
+ (128, 0, 255),
+ (0, 128, 255),
+ (128, 255, 0),
+]
+
+video_length = 201
+W = 1024
+H = W
+device = "cuda"
+
+def get_pipe_image_and_video_predictor():
+ vae = WanVAE(
+ vae_pth="./models/vae/Wan2.1_VAE.pth",
+ dtype=torch.float16,
+ )
+
+ pipe_device = "cuda"
+
+ texture_remover = VaceWanModel.from_config(
+ "./models/texture_remover/texture_remover.json"
+ )
+ ckpt = torch.load(
+ "./models/texture_remover/texture_remover.pth",
+ map_location="cpu",
+ )
+ texture_remover.load_state_dict(ckpt)
+ texture_remover = texture_remover.to(dtype=torch.float16, device=pipe_device)
+
+ model = VaceMMModel.from_config(
+ "./models/refacade/refacade.json"
+ )
+ ckpt = torch.load(
+ "./models/refacade/refacade.pth",
+ map_location="cpu",
+ )
+ model.load_state_dict(ckpt)
+ model = model.to(dtype=torch.float16, device=pipe_device)
+
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=1,
+ )
+ pipe = RefacadePipeline(
+ vae=vae,
+ transformer=model,
+ texture_remover=texture_remover,
+ scheduler=sample_scheduler,
+ )
+ pipe.to(pipe_device)
+
+ sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
+ config = "sam2_hiera_l.yaml"
+
+ video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device="cuda")
+ model_sam = build_sam2(config, sam2_checkpoint, device="cuda")
+ model_sam.image_size = 1024
+ image_predictor = SAM2ImagePredictor(sam_model=model_sam)
+
+ return pipe, image_predictor, video_predictor
+
+
+def get_video_info(video_path, video_state):
+ video_state["input_points"] = []
+ video_state["scaled_points"] = []
+ video_state["input_labels"] = []
+ video_state["frame_idx"] = 0
+
+ vr = VideoReader(video_path, ctx=cpu(0))
+ first_frame = vr[0].asnumpy()
+ del vr
+
+ if first_frame.shape[0] > first_frame.shape[1]:
+ W_ = W
+ H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1])
+ else:
+ H_ = H
+ W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0])
+
+ first_frame = cv2.resize(first_frame, (W_, H_))
+ video_state["origin_images"] = np.expand_dims(first_frame, axis=0)
+ video_state["inference_state"] = None
+ video_state["video_path"] = video_path
+ video_state["masks"] = None
+ video_state["painted_images"] = None
+ image = Image.fromarray(first_frame)
+ return image
+
+
+def segment_frame(evt: gr.SelectData, label, video_state):
+ if video_state["origin_images"] is None:
+ return None
+ x, y = evt.index
+ new_point = [x, y]
+ label_value = 1 if label == "Positive" else 0
+
+ video_state["input_points"].append(new_point)
+ video_state["input_labels"].append(label_value)
+ height, width = video_state["origin_images"][0].shape[0:2]
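+    # Normalize the accumulated clicks to [0, 1]; the predictor below is called with normalize_coords=False.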
+ scaled_points = []
+ for pt in video_state["input_points"]:
+ sx = pt[0] / width
+ sy = pt[1] / height
+ scaled_points.append([sx, sy])
+
+ video_state["scaled_points"] = scaled_points
+
+ image_predictor.set_image(video_state["origin_images"][0])
+ mask, _, _ = image_predictor.predict(
+ point_coords=video_state["scaled_points"],
+ point_labels=video_state["input_labels"],
+ multimask_output=False,
+ normalize_coords=False,
+ )
+
+ mask = np.squeeze(mask)
+ mask = cv2.resize(mask, (width, height))
+ mask = mask[:, :, None]
+
+ color = (
+ np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
+ / 255.0
+ )
+ color = color[None, None, :]
+ org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
+ painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
+ painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))
+ video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
+ video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0)
+
+ for i in range(len(video_state["input_points"])):
+ point = video_state["input_points"][i]
+ if video_state["input_labels"][i] == 0:
+ cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1)
+ else:
+ cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1)
+
+ return Image.fromarray(painted_image)
+
+
+def clear_clicks(video_state):
+ video_state["input_points"] = []
+ video_state["input_labels"] = []
+ video_state["scaled_points"] = []
+ video_state["inference_state"] = None
+ video_state["masks"] = None
+ video_state["painted_images"] = None
+ return (
+ Image.fromarray(video_state["origin_images"][0])
+ if video_state["origin_images"] is not None
+ else None
+ )
+
+
+def set_ref_image(ref_img, ref_state):
+ if ref_img is None:
+ return None
+
+ if isinstance(ref_img, Image.Image):
+ img_np = np.array(ref_img)
+ else:
+ img_np = ref_img
+
+ ref_state["origin_image"] = img_np
+ ref_state["input_points"] = []
+ ref_state["input_labels"] = []
+ ref_state["scaled_points"] = []
+ ref_state["mask"] = None
+
+ return Image.fromarray(img_np)
+
+
+def segment_ref_frame(evt: gr.SelectData, label, ref_state):
+ if ref_state["origin_image"] is None:
+ return None
+
+ x, y = evt.index
+ new_point = [x, y]
+ label_value = 1 if label == "Positive" else 0
+
+ ref_state["input_points"].append(new_point)
+ ref_state["input_labels"].append(label_value)
+
+ img = ref_state["origin_image"]
+ h, w = img.shape[:2]
+
+ scaled_points = []
+ for pt in ref_state["input_points"]:
+ sx = pt[0] / w
+ sy = pt[1] / h
+ scaled_points.append([sx, sy])
+ ref_state["scaled_points"] = scaled_points
+
+ image_predictor.set_image(img)
+ mask, _, _ = image_predictor.predict(
+ point_coords=scaled_points,
+ point_labels=ref_state["input_labels"],
+ multimask_output=False,
+ normalize_coords=False,
+ )
+
+ mask = np.squeeze(mask)
+ mask = cv2.resize(mask, (w, h))
+ mask = mask[:, :, None]
+ ref_state["mask"] = mask[:, :, 0]
+
+ color = (
+ np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
+ / 255.0
+ )
+ color = color[None, None, :]
+ org_image = img.astype(np.float32) / 255.0
+ painted = (1 - mask * 0.5) * org_image + mask * 0.5 * color
+ painted = np.uint8(np.clip(painted * 255, 0, 255))
+
+ for i in range(len(ref_state["input_points"])):
+ point = ref_state["input_points"][i]
+ if ref_state["input_labels"][i] == 0:
+ cv2.circle(painted, point, radius=3, color=(0, 0, 255), thickness=-1)
+ else:
+ cv2.circle(painted, point, radius=3, color=(255, 0, 0), thickness=-1)
+
+ return Image.fromarray(painted)
+
+
+def clear_ref_clicks(ref_state):
+ ref_state["input_points"] = []
+ ref_state["input_labels"] = []
+ ref_state["scaled_points"] = []
+ ref_state["mask"] = None
+ if ref_state["origin_image"] is None:
+ return None
+ return Image.fromarray(ref_state["origin_image"])
+
+
+@spaces.GPU(duration=40)
+def track_video(n_frames, video_state):
+ input_points = video_state["input_points"]
+ input_labels = video_state["input_labels"]
+ frame_idx = video_state["frame_idx"]
+ obj_id = video_state["obj_id"]
+ scaled_points = video_state["scaled_points"]
+
+ vr = VideoReader(video_state["video_path"], ctx=cpu(0))
+ height, width = vr[0].shape[0:2]
+    images = [vr[i].asnumpy() for i in range(min(len(vr), int(n_frames)))]
+ del vr
+
+ if images[0].shape[0] > images[0].shape[1]:
+ W_ = W
+ H_ = int(W_ * images[0].shape[0] / images[0].shape[1])
+ else:
+ H_ = H
+ W_ = int(H_ * images[0].shape[1] / images[0].shape[0])
+
+ images = [cv2.resize(img, (W_, H_)) for img in images]
+ video_state["origin_images"] = images
+ images = np.array(images)
+
+ sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
+ config = "sam2_hiera_l.yaml"
+ video_predictor_local = build_sam2_video_predictor(
+ config, sam2_checkpoint, device="cuda"
+ )
+
+ inference_state = video_predictor_local.init_state(
+ images=images / 255, device="cuda"
+ )
+
+    mask = torch.from_numpy(video_state["masks"][0])
+    if mask.ndim == 3:
+        mask = mask[:, :, 0]
+
+ video_predictor_local.add_new_mask(
+ inference_state=inference_state,
+ frame_idx=0,
+ obj_id=obj_id,
+ mask=mask,
+ )
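+    # Propagate the first-frame mask through the clip and overlay each predicted mask for the preview video.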
+
+ output_frames = []
+ mask_frames = []
+ color = (
+ np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
+ / 255.0
+ )
+ color = color[None, None, :]
+ for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
+ inference_state
+ ):
+ frame = images[out_frame_idx].astype(np.float32) / 255.0
+        mask = np.zeros((H_, W_, 3), dtype=np.float32)
+ for i, logit in enumerate(out_mask_logits):
+ out_mask = logit.cpu().squeeze().detach().numpy()
+ out_mask = (out_mask[:, :, None] > 0).astype(np.float32)
+ mask += out_mask
+ mask = np.clip(mask, 0, 1)
+ mask = cv2.resize(mask, (W_, H_))
+ mask_frames.append(mask)
+ painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
+ painted = np.uint8(np.clip(painted * 255, 0, 255))
+ output_frames.append(painted)
+
+ video_state["masks"] = mask_frames
+ video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
+ clip = ImageSequenceClip(output_frames, fps=15)
+ clip.write_videofile(
+ video_file, codec="libx264", audio=False, verbose=False, logger=None
+ )
+ print("Tracking done")
+ return video_file, video_state
+
+
+@spaces.GPU(duration=50)
+def inference_and_return_video(
+ dilate_radius,
+ num_inference_steps,
+ guidance_scale,
+ ref_patch_ratio,
+ fg_threshold,
+ seed,
+ video_state,
+ ref_state,
+):
+ if video_state["origin_images"] is None or video_state["masks"] is None:
+ print("No video frames or video masks.")
+ return None, None, None
+
+ if ref_state["origin_image"] is None or ref_state["mask"] is None:
+ print("Reference image or reference mask missing.")
+ return None, None, None
+
+ images = video_state["origin_images"]
+ masks = video_state["masks"]
+
+ video_frames = []
+ mask_frames = []
+ for img, msk in zip(images, masks):
+ if not isinstance(img, np.ndarray):
+ img = np.asarray(img)
+ img_pil = Image.fromarray(img.astype(np.uint8))
+
+ if isinstance(msk, np.ndarray):
+ if msk.ndim == 3:
+ m2 = msk[..., 0]
+ else:
+ m2 = msk
+ else:
+ m2 = np.asarray(msk)
+
+ m2 = (m2 > 0.5).astype(np.uint8) * 255
+ msk_pil = Image.fromarray(m2, mode="L")
+
+ video_frames.append(img_pil)
+ mask_frames.append(msk_pil)
+
+ num_frames = len(video_frames)
+
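+    # Run the diffusion pipeline at a fixed 480p resolution (832x480 or 480x832), matching the clip's orientation.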
+ h0, w0 = images[0].shape[:2]
+ if h0 > w0:
+ height = 832
+ width = 480
+ else:
+ height = 480
+ width = 832
+
+ ref_img_np = ref_state["origin_image"]
+ ref_mask_np = ref_state["mask"]
+
+ ref_img_pil = Image.fromarray(ref_img_np.astype(np.uint8))
+ ref_mask_bin = (ref_mask_np > 0.5).astype(np.uint8) * 255
+ ref_mask_pil = Image.fromarray(ref_mask_bin, mode="L")
+
+ pipe.to("cuda")
+ with torch.no_grad():
+ retex_frames, mesh_frames, ref_img_out = pipe(
+ video=video_frames,
+ mask=mask_frames,
+ reference_image=ref_img_pil,
+ reference_mask=ref_mask_pil,
+ conditioning_scale=1.0,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ dilate_radius=int(dilate_radius),
+ num_inference_steps=int(num_inference_steps),
+ guidance_scale=float(guidance_scale),
+ reference_patch_ratio=float(ref_patch_ratio),
+ fg_thresh=float(fg_threshold),
+            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
+ return_dict=False,
+ )
+
+ retex_frames_uint8 = (np.clip(retex_frames[0], 0.0, 1.0) * 255).astype(np.uint8)
+
+ mesh_frames_uint8 = (np.clip(mesh_frames[0], 0.0, 1.0) * 255).astype(np.uint8)
+
+
+ retex_output_frames = [frame for frame in retex_frames_uint8]
+ mesh_output_frames = [frame for frame in mesh_frames_uint8]
+
+ if ref_img_out.dtype != np.uint8:
+ ref_img_out = (np.clip(ref_img_out, 0.0, 1.0) * 255).astype(np.uint8)
+
+ retex_video_file = f"/tmp/{time.time()}-{random.random()}-refacade_output.mp4"
+ retex_clip = ImageSequenceClip(retex_output_frames, fps=16)
+ retex_clip.write_videofile(
+ retex_video_file, codec="libx264", audio=False, verbose=False, logger=None
+ )
+
+ mesh_video_file = f"/tmp/{time.time()}-{random.random()}-mesh_output.mp4"
+ mesh_clip = ImageSequenceClip(mesh_output_frames, fps=16)
+ mesh_clip.write_videofile(
+ mesh_video_file, codec="libx264", audio=False, verbose=False, logger=None
+ )
+
+ ref_image_to_show = ref_img_out
+
+ return retex_video_file, mesh_video_file, ref_image_to_show
+
+
+text = """
+
+ Refaçade Video Retexture Demo
+
+
+ Video mask from SAM2, Reference mask from SAM2 image clicks, RefacadePipeline for object retexture task.
+
+"""
+
+pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor()
+
+with gr.Blocks() as demo:
+ video_state = gr.State(
+ {
+ "origin_images": None,
+ "inference_state": None,
+ "masks": None,
+ "painted_images": None,
+ "video_path": None,
+ "input_points": [],
+ "scaled_points": [],
+ "input_labels": [],
+ "frame_idx": 0,
+ "obj_id": 1,
+ }
+ )
+
+ ref_state = gr.State(
+ {
+ "origin_image": None,
+ "input_points": [],
+ "input_labels": [],
+ "scaled_points": [],
+ "mask": None,
+ }
+ )
+
+ gr.Markdown(f"{text}
")
+
+ with gr.Column():
+ video_input = gr.Video(label="Upload Video", elem_id="my-video1")
+ get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn")
+
+ gr.Examples(
+ examples=[
+ ["./examples/1.mp4"],
+ ["./examples/2.mp4"],
+ ["./examples/3.mp4"],
+ ["./examples/4.mp4"],
+ ["./examples/5.mp4"],
+ ["./examples/6.mp4"],
+ ],
+ inputs=[video_input],
+ label="You can upload or choose a source video below to retexture.",
+ elem_id="my-btn2"
+ )
+
+ image_output = gr.Image(
+ label="First Frame Segmentation",
+ interactive=True,
+ elem_id="my-video",
+ )
+
+ demo.css = """
+ #my-btn {
+ width: 60% !important;
+ margin: 0 auto;
+ }
+ #my-video1 {
+ width: 60% !important;
+ height: 35% !important;
+ margin: 0 auto;
+ }
+ #my-video {
+ width: 60% !important;
+ height: 35% !important;
+ margin: 0 auto;
+ }
+ #my-md {
+ margin: 0 auto;
+ }
+ #my-btn2 {
+ width: 60% !important;
+ margin: 0 auto;
+ }
+ #my-btn2 button {
+ width: 120px !important;
+ max-width: 120px !important;
+ min-width: 120px !important;
+ height: 70px !important;
+ max-height: 70px !important;
+ min-height: 70px !important;
+ margin: 8px !important;
+ border-radius: 8px !important;
+ overflow: hidden !important;
+ white-space: normal !important;
+ }
+ #my-btn3 {
+ width: 60% !important;
+ margin: 0 auto;
+ }
+ #ref_title {
+ text-align: center;
+ }
+ #ref-image {
+ width: 60% !important;
+ height: 35% !important;
+ margin: 0 auto;
+ }
+ #ref-mask {
+ width: 60% !important;
+ height: 35% !important;
+ margin: 0 auto;
+ }
+ #mesh-row {
+ width: 60% !important;
+ margin: 0 auto;
+ }
+ """
+
+ with gr.Row(elem_id="my-btn"):
+ point_prompt = gr.Radio(
+ ["Positive", "Negative"], label="Click Type", value="Positive"
+ )
+ clear_btn = gr.Button("Clear All Clicks")
+
+ with gr.Row(elem_id="my-btn"):
+ n_frames_slider = gr.Slider(
+ minimum=1, maximum=201, value=81, step=1, label="Tracking Frames (4N+1)"
+ )
+ track_btn = gr.Button("Tracking")
+ video_output = gr.Video(label="Tracking Result", elem_id="my-video")
+
+ gr.Markdown("Reference Image & Mask (SAM2 Points)", elem_id="ref_title")
+
+ ref_image_input = gr.Image(
+ label="Upload Reference Image", elem_id="ref-image", interactive=True
+ )
+ gr.Examples(
+ examples=[
+ ["./examples/reference_image/1.png"],
+ ["./examples/reference_image/2.png"],
+ ["./examples/reference_image/3.png"],
+ ["./examples/reference_image/4.png"],
+ ["./examples/reference_image/5.png"],
+ ["./examples/reference_image/6.png"],
+ ["./examples/reference_image/7.png"],
+ ["./examples/reference_image/8.png"],
+ ["./examples/reference_image/9.png"],
+ ],
+ inputs=[ref_image_input],
+ label="You can upload or choose a reference image below to retexture.",
+ elem_id="my-btn3"
+ )
+ ref_image_display = gr.Image(
+ label="Reference Mask Segmentation",
+ elem_id="ref-mask",
+ interactive=True,
+ )
+
+ with gr.Row(elem_id="my-btn"):
+ ref_point_prompt = gr.Radio(
+ ["Positive", "Negative"], label="Ref Click Type", value="Positive"
+ )
+ ref_clear_btn = gr.Button("Clear Ref Clicks")
+
+ with gr.Column(elem_id="my-btn"):
+
+ dilate_radius_slider = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=3,
+ step=1,
+ label="Mask Dilation Radius",
+ )
+ inference_steps_slider = gr.Slider(
+ minimum=1,
+ maximum=50,
+ value=20,
+ step=1,
+ label="Num Inference Steps",
+ )
+ guidance_slider = gr.Slider(
+ minimum=1.0,
+ maximum=3.0,
+ value=1.5,
+ step=0.1,
+ label="Guidance Scale",
+ )
+ ref_patch_slider = gr.Slider(
+ minimum=0.05,
+ maximum=1.0,
+ value=0.1,
+ step=0.05,
+ label="Reference Patch Ratio",
+ )
+ fg_threshold_slider = gr.Slider(
+ minimum=0.7,
+ maximum=1.0,
+ value=0.8,
+ step=0.01,
+ label="Jigsaw Patches' Foreground Coverage Threshold",
+ )
+ seed_slider = gr.Slider(
+ minimum=0,
+ maximum=2147483647,
+ value=42,
+ step=1,
+ label="Seed",
+ )
+
+ remove_btn = gr.Button("Retexture", elem_id="my-btn")
+
+ with gr.Row(elem_id="mesh-row"):
+ mesh_video = gr.Video(label="Untextured Object")
+ ref_image_final = gr.Image(
+ label="Jigsawed Reference Image",
+ interactive=False,
+ )
+
+ remove_video = gr.Video(label="Retexture Results", elem_id="my-video")
+
+ remove_btn.click(
+ inference_and_return_video,
+ inputs=[
+ dilate_radius_slider,
+ inference_steps_slider,
+ guidance_slider,
+ ref_patch_slider,
+ fg_threshold_slider,
+ seed_slider,
+ video_state,
+ ref_state,
+ ],
+ outputs=[remove_video, mesh_video, ref_image_final],
+ )
+
+ get_info_btn.click(
+ get_video_info,
+ inputs=[video_input, video_state],
+ outputs=image_output,
+ )
+
+ image_output.select(
+ fn=segment_frame,
+ inputs=[point_prompt, video_state],
+ outputs=image_output,
+ )
+
+ clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output)
+
+ track_btn.click(
+ track_video,
+ inputs=[n_frames_slider, video_state],
+ outputs=[video_output, video_state],
+ )
+
+ ref_image_input.change(
+ set_ref_image,
+ inputs=[ref_image_input, ref_state],
+ outputs=ref_image_display,
+ )
+ ref_image_display.select(
+ fn=segment_ref_frame,
+ inputs=[ref_point_prompt, ref_state],
+ outputs=ref_image_display,
+ )
+ ref_clear_btn.click(
+ clear_ref_clicks, inputs=ref_state, outputs=ref_image_display
+ )
+
+demo.launch()
+
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcac726ae8483973a267d307011265841652eec8
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,718 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+import PIL.Image
+import torch
+import math
+import random
+import numpy as np
+import torch.nn.functional as F
+from typing import Tuple
+from PIL import Image
+
+from vae import WanVAE
+from vace.models.wan.modules.model_mm import VaceMMModel
+from vace.models.wan.modules.model_tr import VaceWanModel
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import WanLoraLoaderMixin
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import BaseOutput
+from dataclasses import dataclass
+
+
+@dataclass
+class RefacadePipelineOutput(BaseOutput):
+ frames: torch.Tensor
+ meshes: torch.Tensor
+ ref_img: torch.Tensor
+
+
+logger = logging.get_logger(__name__)
+
+
+@torch.no_grad()
+def _pad_to_multiple(x: torch.Tensor, multiple: int, mode: str = "reflect"):
+ H, W = x.shape[-2], x.shape[-1]
+ pad_h = (multiple - H % multiple) % multiple
+ pad_w = (multiple - W % multiple) % multiple
+ pad = (0, pad_w, 0, pad_h)
+ if pad_h or pad_w:
+ x = F.pad(x, pad, mode=mode)
+ return x, pad
+
+
+@torch.no_grad()
+def _unpad(x: torch.Tensor, pad):
+ l, r, t, b = pad
+ H, W = x.shape[-2], x.shape[-1]
+ return x[..., t:H - b if b > 0 else H, l:W - r if r > 0 else W]
+
+
+@torch.no_grad()
+def _resize(x: torch.Tensor, size: tuple, is_mask: bool):
+ mode = "nearest" if is_mask else "bilinear"
+ if is_mask:
+ return F.interpolate(x, size=size, mode=mode)
+ else:
+ return F.interpolate(x, size=size, mode=mode, align_corners=False)
+
+
+@torch.no_grad()
+def _center_scale_foreground_to_canvas(
+ x_f: torch.Tensor,
+ m_f: torch.Tensor,
+ target_hw: tuple,
+ bg_value: float = 1.0,
+):
+ C, H, W = x_f.shape
+ H2, W2 = target_hw
+ device = x_f.device
+ ys, xs = (m_f > 0.5).nonzero(as_tuple=True)
+ canvas = torch.full((C, H2, W2), bg_value, dtype=x_f.dtype, device=device)
+ mask_canvas = torch.zeros((1, H2, W2), dtype=x_f.dtype, device=device)
+ if ys.numel() == 0:
+ return canvas, mask_canvas
+
+ y0, y1 = ys.min().item(), ys.max().item()
+ x0, x1 = xs.min().item(), xs.max().item()
+ crop_img = x_f[:, y0:y1 + 1, x0:x1 + 1]
+ crop_msk = m_f[y0:y1 + 1, x0:x1 + 1].unsqueeze(0)
+ hc, wc = crop_msk.shape[-2], crop_msk.shape[-1]
+ s = min(H2 / max(1, hc), W2 / max(1, wc))
+ Ht = max(1, min(H2, int(math.floor(hc * s))))
+ Wt = max(1, min(W2, int(math.floor(wc * s))))
+ crop_img_up = _resize(crop_img.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
+ crop_msk_up = _resize(crop_msk.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)
+ crop_msk_up = (crop_msk_up > 0.5).to(crop_msk_up.dtype)
+
+ top = (H2 - Ht) // 2
+ left = (W2 - Wt) // 2
+ canvas[:, top:top + Ht, left:left + Wt] = crop_img_up
+ mask_canvas[:, top:top + Ht, left:left + Wt] = crop_msk_up
+ return canvas, mask_canvas
+
+
+@torch.no_grad()
+def _sample_patch_size_from_hw(
+ H: int,
+ W: int,
+ ratio: float = 0.2,
+ min_px: int = 16,
+ max_px: Optional[int] = None,
+) -> int:
+ r = ratio
+ raw = r * min(H, W)
+ if max_px is None:
+ max_px = min(192, min(H, W))
+ P = int(round(raw))
+ P = max(min_px, min(P, max_px))
+ P = int(P)
+ return P
+
+
+@torch.no_grad()
+def _masked_patch_pack_to_center_rectangle(
+ x_f: torch.Tensor,
+ m_f: torch.Tensor,
+ patch: int,
+ fg_thresh: float = 0.8,
+ bg_value: float = 1.0,
+ min_patches: int = 4,
+ flip_prob: float = 0.5,
+ use_morph_erode: bool = False,
+):
+
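+    # Build the "jigsawed" reference: center-scale the masked foreground, cut it into P x P patches,
+    # keep patches whose foreground coverage passes fg_thresh (relaxing the threshold if too few qualify),
+    # randomly shuffle/flip them, and pack them into a centered rectangle on the canvas.
+    # The returned flag reports whether the plain centered foreground was used as a fallback.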
+ C, H, W = x_f.shape
+ device = x_f.device
+ P = int(patch)
+
+ x_pad, pad = _pad_to_multiple(x_f, P, mode="reflect")
+ l, r, t, b = pad
+ H2, W2 = x_pad.shape[-2], x_pad.shape[-1]
+ m_pad = F.pad(m_f.unsqueeze(0).unsqueeze(0), (l, r, t, b), mode="constant", value=0.0).squeeze(0)
+
+ cs_img, cs_msk = _center_scale_foreground_to_canvas(x_pad, m_pad.squeeze(0), (H2, W2), bg_value)
+ if (cs_msk > 0.5).sum() == 0:
+ out_img = _unpad(cs_img, pad).clamp_(-1, 1)
+ out_msk = _unpad(cs_msk, pad).clamp_(0, 1)
+ return out_img, out_msk, True
+
+ m_eff = cs_msk
+ if use_morph_erode:
+ erode_px = int(max(1, min(6, round(P * 0.03))))
+ m_eff = 1.0 - F.max_pool2d(1.0 - cs_msk, kernel_size=2 * erode_px + 1, stride=1, padding=erode_px)
+
+ x_pad2, pad2 = _pad_to_multiple(cs_img, P, mode="reflect")
+ m_pad2 = F.pad(m_eff, pad2, mode="constant", value=0.0)
+ H3, W3 = x_pad2.shape[-2], x_pad2.shape[-1]
+
+ m_pool = F.avg_pool2d(m_pad2, kernel_size=P, stride=P).view(-1)
+
+ base_thr = float(fg_thresh)
+ thr_candidates = [base_thr, max(base_thr - 0.05, 0.75), max(base_thr - 0.10, 0.60)]
+
+ x_unf = F.unfold(x_pad2.unsqueeze(0), kernel_size=P, stride=P)
+ N = x_unf.shape[-1]
+
+ sel = None
+ for thr in thr_candidates:
+ idx = (m_pool >= (thr - 1e-6)).nonzero(as_tuple=False).squeeze(1)
+ if idx.numel() >= min_patches:
+ sel = idx
+ break
+ if sel is None:
+ img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
+ msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
+ return img_fallback, msk_fallback, True
+
+ sel = sel.to(device=device, dtype=torch.long)
+ sel = sel[(sel >= 0) & (sel < N)]
+ if sel.numel() == 0:
+ img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
+ msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
+ return img_fallback, msk_fallback, True
+
+ perm = torch.randperm(sel.numel(), device=device, dtype=torch.long)
+ sel = sel[perm]
+ chosen_x = x_unf[:, :, sel]
+ K = chosen_x.shape[-1]
+ if K == 0:
+ img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
+ msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
+ return img_fallback, msk_fallback, True
+
+ if flip_prob > 0:
+ cx4 = chosen_x.view(1, C, P, P, K)
+ do_flip = (torch.rand(K, device=device) < flip_prob)
+ coin = (torch.rand(K, device=device) < 0.5)
+ flip_h = do_flip & coin
+ flip_v = do_flip & (~coin)
+ if flip_h.any():
+ cx4[..., flip_h] = cx4[..., flip_h].flip(dims=[3])
+ if flip_v.any():
+ cx4[..., flip_v] = cx4[..., flip_v].flip(dims=[2])
+ chosen_x = cx4.view(1, C * P * P, K)
+
+ max_cols = max(1, W3 // P)
+ max_rows = max(1, H3 // P)
+ capacity = max_rows * max_cols
+ K_cap = min(K, capacity)
+ cols = int(max(1, min(int(math.floor(math.sqrt(K_cap))), max_cols)))
+ rows_full = min(max_rows, K_cap // cols)
+ K_used = rows_full * cols
+ if K_used == 0:
+ img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1)
+ msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1)
+ return img_fallback, msk_fallback, True
+
+ chosen_x = chosen_x[:, :, :K_used]
+ rect_unf = torch.full((1, C * P * P, rows_full * cols), bg_value, device=device, dtype=x_f.dtype)
+ rect_unf[:, :, :K_used] = chosen_x
+ rect = F.fold(rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)
+
+ ones_patch = torch.ones((1, 1 * P * P, K_used), device=device, dtype=x_f.dtype)
+ mask_rect_unf = torch.zeros((1, 1 * P * P, rows_full * cols), device=device, dtype=x_f.dtype)
+ mask_rect_unf[:, :, :K_used] = ones_patch
+ rect_mask = F.fold(mask_rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0)
+
+ Hr, Wr = rect.shape[-2], rect.shape[-1]
+ s = min(H3 / max(1, Hr), W3 / max(1, Wr))
+ Ht = min(max(1, int(math.floor(Hr * s))), H3)
+ Wt = min(max(1, int(math.floor(Wr * s))), W3)
+
+ rect_up = _resize(rect.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0)
+ rect_mask_up = _resize(rect_mask.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0)
+
+ canvas_x = torch.full((C, H3, W3), bg_value, device=device, dtype=x_f.dtype)
+ canvas_m = torch.zeros((1, H3, W3), device=device, dtype=x_f.dtype)
+ top, left = (H3 - Ht) // 2, (W3 - Wt) // 2
+ canvas_x[:, top:top + Ht, left:left + Wt] = rect_up
+ canvas_m[:, top:top + Ht, left:left + Wt] = rect_mask_up
+
+ out_img = _unpad(_unpad(canvas_x, pad2), pad).clamp_(-1, 1)
+ out_msk = _unpad(_unpad(canvas_m, pad2), pad).clamp_(0, 1)
+ return out_img, out_msk, False
+
+
+@torch.no_grad()
+def _compose_centered_foreground(x_f: torch.Tensor, m_f3: torch.Tensor, target_hw: Tuple[int, int], bg_value: float = 1.0):
+ m_bin = (m_f3 > 0.5).float().mean(dim=0)
+ m_bin = (m_bin > 0.5).float()
+ return _center_scale_foreground_to_canvas(x_f, m_bin, target_hw, bg_value)
+
+class RefacadePipeline(DiffusionPipeline, WanLoraLoaderMixin):
+
+ model_cpu_offload_seq = "texture_remover->transformer->vae"
+
+ def __init__(
+ self,
+ vae,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: VaceMMModel = None,
+ texture_remover: VaceWanModel = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ texture_remover=texture_remover,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor_temporal = 4
+ self.vae_scale_factor_spatial = 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
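+        # Precomputed text-encoder embeddings for the empty and negative prompts
+        # (note the machine-specific absolute paths).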
+ self.empty_embedding = torch.load(
+ "/ms/AIGC/huangyouze/fish_pipeline/stage10_t5_encode/image/empty.pt",
+ map_location="cpu"
+ )
+ self.negative_embedding = torch.load(
+ "/ms/AIGC/huangyouze/fish_pipeline/stage10_t5_encode/image/negative.pt",
+ map_location="cpu"
+ )
+
+ def vace_encode_masks(self, masks: torch.Tensor):
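+        # Bring the binary mask to latent resolution: each 8x8 spatial patch is folded into the
+        # channel dimension (64 channels) and the temporal axis is resized to the latent frame count.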
+ masks = masks[:, :1, :, :, :]
+ B, C, D, H, W = masks.shape
+ patch_h, patch_w = self.vae_scale_factor_spatial, self.vae_scale_factor_spatial
+ stride_t = self.vae_scale_factor_temporal
+ patch_count = patch_h * patch_w
+ new_D = (D + stride_t - 1) // stride_t
+ new_H = 2 * (H // (patch_h * 2))
+ new_W = 2 * (W // (patch_w * 2))
+ masks = masks[:, 0]
+ masks = masks.view(B, D, new_H, patch_h, new_W, patch_w)
+ masks = masks.permute(0, 3, 5, 1, 2, 4)
+ masks = masks.reshape(B, patch_count, D, new_H, new_W)
+ masks = F.interpolate(
+ masks,
+ size=(new_D, new_H, new_W),
+ mode="nearest-exact"
+ )
+ return masks
+
+ def preprocess_conditions(
+ self,
+ video: Optional[List[PipelineImageInput]] = None,
+ mask: Optional[List[PipelineImageInput]] = None,
+ reference_image: Optional[PIL.Image.Image] = None,
+ reference_mask: Optional[PIL.Image.Image] = None,
+ batch_size: int = 1,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ reference_patch_ratio: float = 0.2,
+ fg_thresh: float = 0.9,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+
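+        # Fit the video and mask within height x width (dimensions divisible by 16), map the mask to
+        # [0, 1], and build the reference conditioning image: the centered masked foreground when
+        # reference_patch_ratio == 1.0, otherwise a shuffled-patch "jigsaw" of it.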
+ base = self.vae_scale_factor_spatial * 2
+ video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+ if video_height * video_width > height * width:
+ scale_w = width / video_width
+ scale_h = height / video_height
+ video_height, video_width = int(video_height * scale_h), int(video_width * scale_w)
+
+ if video_height % base != 0 or video_width % base != 0:
+ logger.warning(
+ f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
+ )
+ video_height = (video_height // base) * base
+ video_width = (video_width // base) * base
+
+ assert video_height * video_width <= height * width
+
+ video = self.video_processor.preprocess_video(video, video_height, video_width)
+ image_size = (video_height, video_width)
+
+ mask = self.video_processor.preprocess_video(mask, video_height, video_width)
+ mask = torch.clamp((mask + 1) / 2, min=0, max=1)
+
+ video = video.to(dtype=dtype, device=device)
+ mask = mask.to(dtype=dtype, device=device)
+
+ if reference_image is None:
+ raise ValueError("reference_image must be provided when using IMAGE_CONTROL mode.")
+
+ if isinstance(reference_image, (list, tuple)):
+ ref_img_pil = reference_image[0]
+ else:
+ ref_img_pil = reference_image
+
+ if reference_mask is not None and isinstance(reference_mask, (list, tuple)):
+ ref_mask_pil = reference_mask[0]
+ else:
+ ref_mask_pil = reference_mask
+
+ ref_img_t = self.video_processor.preprocess(ref_img_pil, image_size[0], image_size[1])
+ if ref_img_t.dim() == 4 and ref_img_t.shape[0] == 1:
+ ref_img_t = ref_img_t[0]
+ if ref_img_t.shape[0] == 1:
+ ref_img_t = ref_img_t.repeat(3, 1, 1)
+ ref_img_t = ref_img_t.to(dtype=dtype, device=device)
+
+ H, W = image_size
+ if ref_mask_pil is not None:
+ if not isinstance(ref_mask_pil, Image.Image):
+ ref_mask_pil = Image.fromarray(np.array(ref_mask_pil))
+ ref_mask_pil = ref_mask_pil.convert("L")
+ ref_mask_pil = ref_mask_pil.resize((W, H), Image.NEAREST)
+ mask_arr = np.array(ref_mask_pil)
+ m = torch.from_numpy(mask_arr).float() / 255.0
+ m = (m > 0.5).float()
+ ref_msk3 = m.unsqueeze(0).repeat(3, 1, 1)
+ else:
+ ref_msk3 = torch.ones(3, H, W, dtype=dtype)
+
+ ref_msk3 = ref_msk3.to(dtype=dtype, device=device)
+
+ if math.isclose(reference_patch_ratio, 1.0, rel_tol=1e-6, abs_tol=1e-6):
+ cs_img, cs_m = _compose_centered_foreground(
+ x_f=ref_img_t,
+ m_f3=ref_msk3,
+ target_hw=image_size,
+ bg_value=1.0,
+ )
+ ref_img_out = cs_img
+ ref_mask_out = cs_m
+ else:
+ patch = _sample_patch_size_from_hw(
+ H=image_size[0],
+ W=image_size[1],
+ ratio=reference_patch_ratio,
+ )
+
+ m_bin = (ref_msk3 > 0.5).float().mean(dim=0)
+ m_bin = (m_bin > 0.5).float()
+ reshuffled, reshuf_mask, used_fb = _masked_patch_pack_to_center_rectangle(
+ x_f=ref_img_t,
+ m_f=m_bin,
+ patch=patch,
+ fg_thresh=fg_thresh,
+ bg_value=1.0,
+ min_patches=4,
+ )
+
+ ref_img_out = reshuffled
+ ref_mask_out = reshuf_mask
+
+ B = video.shape[0]
+ if batch_size is not None:
+ B = batch_size
+
+ ref_image = ref_img_out.unsqueeze(0).unsqueeze(2).expand(B, -1, -1, -1, -1).contiguous()
+ ref_mask = ref_mask_out.unsqueeze(0).unsqueeze(2).expand(B, 3, -1, -1, -1).contiguous()
+
+ ref_image = ref_image.to(dtype=dtype, device=device)
+ ref_mask = ref_mask.to(dtype=dtype, device=device)
+
+ return video[:, :, :num_frames], mask[:, :, :num_frames], ref_image, ref_mask
+
+ @torch.no_grad()
+ def texture_remove(self, foreground_latent):
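+        # Short 3-step flow-matching denoise with the texture-remover VACE model, conditioned on the
+        # foreground latent and a zero text embedding, to produce an untextured "mesh" latent.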
+ sample_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1)
+ text_embedding = torch.zeros(
+ [256, 4096],
+ device=foreground_latent.device,
+ dtype=foreground_latent.dtype
+ )
+ context = text_embedding.unsqueeze(0).expand(
+ foreground_latent.shape[0], -1, -1
+ ).to(foreground_latent.device)
+ sample_scheduler.set_timesteps(3, device=foreground_latent.device)
+ timesteps = sample_scheduler.timesteps
+ noise = torch.randn_like(
+ foreground_latent,
+ dtype=foreground_latent.dtype,
+ device=foreground_latent.device
+ )
+ seq_len = math.ceil(
+ noise.shape[2] * noise.shape[3] * noise.shape[4] / 4
+ )
+ latents = noise
+ arg_c = {"context": context, "seq_len": seq_len}
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
+ for _, t in enumerate(timesteps):
+ timestep = torch.stack([t]).to(foreground_latent.device)
+ noise_pred_cond = self.texture_remover(
+ latents,
+ t=timestep,
+ vace_context=foreground_latent,
+ vace_context_scale=1,
+ **arg_c
+ )[0]
+ temp_x0 = sample_scheduler.step(
+ noise_pred_cond, t, latents, return_dict=False
+ )[0]
+ latents = temp_x0
+ return latents
+
+ def dilate_mask_hw(self, mask: torch.Tensor, radius: int = 3) -> torch.Tensor:
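+        # Dilate the mask per frame with a (2*radius+1)^2 square structuring element,
+        # implemented as a depthwise convolution followed by a > 0 threshold.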
+ B, C, F_, H, W = mask.shape
+ k = 2 * radius + 1
+ mask_2d = mask.permute(0, 2, 1, 3, 4).reshape(B * F_, C, H, W)
+ kernel = torch.ones(
+ (C, 1, k, k),
+ device=mask.device,
+ dtype=mask.dtype
+ )
+ dilated_2d = F.conv2d(
+ mask_2d,
+ weight=kernel,
+ bias=None,
+ stride=1,
+ padding=radius,
+ groups=C
+ )
+ dilated_2d = (dilated_2d > 0).to(mask.dtype)
+ dilated = dilated_2d.view(B, F_, C, H, W).permute(0, 2, 1, 3, 4)
+ return dilated
+
+ def prepare_vace_latents(
+ self,
+ dilate_radius: int,
+ video: torch.Tensor,
+ mask: torch.Tensor,
+ reference_image: Optional[torch.Tensor] = None,
+ reference_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+
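+        # Encode the VACE conditioning: the background video outside the dilated mask ("inactive"),
+        # the untextured mesh latent of the masked object ("reactive"), the reference image plus a
+        # white negative reference, and the masks downsampled to latent resolution.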
+ vae_dtype = self.vae.dtype
+ video = video.to(dtype=vae_dtype)
+ mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
+ mask_clone = mask.clone()
+ mask = self.dilate_mask_hw(mask, dilate_radius)
+ inactive = video * (1 - mask)
+ reactive = video * mask_clone
+ reactive_latent = self.vae.encode(reactive)
+ mesh_latent = self.texture_remove(reactive_latent)
+
+ inactive_latent = self.vae.encode(inactive)
+ ref_latent = self.vae.encode(reference_image)
+ neg_ref_latent = self.vae.encode(torch.ones_like(reference_image))
+
+ reference_mask = torch.where(reference_mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
+ mask = self.vace_encode_masks(mask)
+ ref_mask = self.vace_encode_masks(reference_mask)
+
+ return inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask
+
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: Optional[PipelineImageInput] = None,
+ mask: Optional[PipelineImageInput] = None,
+ reference_image: Optional[PipelineImageInput] = None,
+ reference_mask: Optional[PipelineImageInput] = None,
+ conditioning_scale: float = 1.0,
+ dilate_radius: int = 3,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 20,
+ guidance_scale: float = 1.5,
+ num_videos_per_prompt: Optional[int] = 1,
+ reference_patch_ratio: float = 0.2,
+ fg_thresh: float = 0.9,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ ):
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+
+ self._guidance_scale = guidance_scale
+
+ device = self._execution_device
+ batch_size = 1
+
+ vae_dtype = self.vae.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ video, mask, reference_image, reference_mask = self.preprocess_conditions(
+ video,
+ mask,
+ reference_image,
+ reference_mask,
+ batch_size,
+ height,
+ width,
+ num_frames,
+ reference_patch_ratio,
+ fg_thresh,
+ torch.float16,
+ device,
+ )
+
+ inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask = self.prepare_vace_latents(dilate_radius, video, mask, reference_image, reference_mask, device)
+ c = torch.cat([inactive_latent, mesh_latent, mask], dim=1)
+ c1 = torch.cat([ref_latent, ref_mask], dim=1)
+ c1_negative = torch.cat(
+ [neg_ref_latent, torch.zeros_like(ref_mask)],
+ dim=1
+ )
+
+ num_channels_latents = 16
+ noise = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float16,
+ device,
+ generator,
+ latents,
+ )
+
+ latents_cond = torch.cat([ref_latent, noise], dim=2)
+ latents_uncond = torch.cat([neg_ref_latent, noise], dim=2)
+
+ seq_len = math.ceil(
+ latents_cond.shape[2] *
+ latents_cond.shape[3] *
+ latents_cond.shape[4] / 4
+ )
+ seq_len_ref = math.ceil(
+ ref_latent.shape[2] *
+ ref_latent.shape[3] *
+ ref_latent.shape[4] / 4
+ )
+ context = self.empty_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
+ context_neg = self.negative_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device)
+ arg_c = {
+ "context": context,
+ "seq_len": seq_len,
+ "seq_len_ref": seq_len_ref
+ }
+ arg_c_null = {
+ "context": context_neg,
+ "seq_len": seq_len,
+ "seq_len_ref": seq_len_ref
+ }
+
+ self._num_timesteps = len(timesteps)
+
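+        # Denoising loop: the clean reference latent stays prepended on the frame axis, only the video
+        # latents (index 1 onward) are stepped by the scheduler, and classifier-free guidance contrasts
+        # the reference-conditioned prediction with a negative-reference prediction.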
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ self._current_timestep = t
+ timestep = t.expand(batch_size)
+
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
+ noise_pred = self.transformer(
+ latents_cond,
+ t=timestep,
+ vace_context=c,
+ ref_context=c1,
+ vace_context_scale=conditioning_scale,
+ **arg_c,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ latents_uncond,
+ t=timestep,
+ vace_context=c,
+ ref_context=c1_negative,
+ vace_context_scale=0,
+ **arg_c_null,
+ )[0]
+ noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)).unsqueeze(0)
+                temp_x0 = self.scheduler.step(
+                    noise_pred[:, :, 1:], t, latents_cond[:, :, 1:], return_dict=False
+                )[0]
+ latents_cond = torch.cat([ref_latent, temp_x0], dim=2)
+ latents_uncond = torch.cat([neg_ref_latent, temp_x0], dim=2)
+ progress_bar.update()
+
+
+ self._current_timestep = None
+
+        if output_type != "latent":
+ latents = temp_x0
+ latents = latents.to(vae_dtype)
+ video = self.vae.decode(latents)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ mesh = self.vae.decode(mesh_latent.to(vae_dtype))
+ mesh = self.video_processor.postprocess_video(mesh, output_type=output_type)
+ ref_img = reference_image.cpu().squeeze(0).squeeze(1).permute(1, 2, 0).numpy()
+            ref_img = np.clip((ref_img + 1.0) * 127.5, 0, 255).astype(np.uint8)
+ else:
+ video = temp_x0
+ mesh = mesh_latent
+ ref_img = ref_latent
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video, mesh, ref_img)
+
+ return RefacadePipelineOutput(frames=video, meshes=mesh, ref_img=ref_img)
diff --git a/sam2/SAM2-Video-Predictor/checkpoints/README.md b/sam2/SAM2-Video-Predictor/checkpoints/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..87274582326ed60c599f54b462b76da02085f813
--- /dev/null
+++ b/sam2/SAM2-Video-Predictor/checkpoints/README.md
@@ -0,0 +1,58 @@
+---
+license: apache-2.0
+pipeline_tag: mask-generation
+library_name: sam2
+---
+
+Repository for SAM 2: Segment Anything in Images and Videos, a foundation model towards solving promptable visual segmentation in images and videos from FAIR. See the [SAM 2 paper](https://arxiv.org/abs/2408.00714) for more information.
+
+The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
+
+## Usage
+
+For image prediction:
+
+```python
+import torch
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+    predictor.set_image(<your_image>)
+    masks, _, _ = predictor.predict(<input_prompts>)
+```
+
+For video prediction:
+
+```python
+import torch
+from sam2.sam2_video_predictor import SAM2VideoPredictor
+
+predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+    state = predictor.init_state(<your_video>)
+
+ # add new prompts and instantly get the output on the same frame
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
+
+ # propagate the prompts to get masklets throughout the video
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
+ ...
+```
+
+Refer to the [demo notebooks](https://github.com/facebookresearch/segment-anything-2/tree/main/notebooks) for details.
+
+### Citation
+
+To cite the paper, model, or software, please use the below:
+```
+@article{ravi2024sam2,
+ title={SAM 2: Segment Anything in Images and Videos},
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
+ journal={arXiv preprint arXiv:2408.00714},
+ url={https://arxiv.org/abs/2408.00714},
+ year={2024}
+}
+```
\ No newline at end of file
diff --git a/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml b/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2b92b5b3a7392b627a74c46f29195ffdaab82d82
--- /dev/null
+++ b/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
\ No newline at end of file
diff --git a/sam2/__init__.py b/sam2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05adb1c40ab94d5eaa10da430ce9e9fd0e62bc9
--- /dev/null
+++ b/sam2/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hydra import initialize
+
+from .build_sam import load_model
+
+initialize("configs", version_base="1.2")
diff --git a/sam2/__pycache__/__init__.cpython-311.pyc b/sam2/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4027f5036ad5a33fae0b2858d8863728e8b23b7e
Binary files /dev/null and b/sam2/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sam2/__pycache__/build_sam.cpython-311.pyc b/sam2/__pycache__/build_sam.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de0ffe3d050017efdd8e10d642681c2ae2eab275
Binary files /dev/null and b/sam2/__pycache__/build_sam.cpython-311.pyc differ
diff --git a/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66ca05f50322d7753a239f062900daabe18a6f75
Binary files /dev/null and b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc differ
diff --git a/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc b/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34b84f2c5406a66130071c7480ce5aa94f88eccb
Binary files /dev/null and b/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc differ
diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c12f48cc0e218fb1634431590bfaf93a1151e9
--- /dev/null
+++ b/sam2/automatic_mask_generator.py
@@ -0,0 +1,434 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SAM2AutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: SAM2Base,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.8,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ mask_threshold: float = 0.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ use_m2m: bool = False,
+ multimask_output: bool = True,
+ ) -> None:
+ """
+ Using a SAM 2 model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM 2 with a HieraL backbone.
+
+ Arguments:
+ model (Sam): The SAM 2 model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+                calculating the stability score.
+ mask_threshold (float): Threshold for binarizing the mask logits
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
+ multimask_output (bool): Whether to output multimask at each point of the grid.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ try:
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+ except ImportError as e:
+ print("Please install pycocotools")
+ raise e
+
+ self.predictor = SAM2ImagePredictor(
+ model,
+ max_hole_area=min_mask_region_area,
+ max_sprinkle_area=min_mask_region_area,
+ )
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.mask_threshold = mask_threshold
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+ self.use_m2m = use_m2m
+ self.multimask_output = multimask_output
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [
+ coco_encode_rle(rle) for rle in mask_data["rles"]
+ ]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(
+ points, cropped_im_size, crop_box, orig_size, normalize=True
+ )
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_predictor()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ normalize=False,
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ points = torch.as_tensor(points, device=self.predictor.device)
+ in_points = self.predictor._transforms.transform_coords(
+ points, normalize=normalize, orig_hw=im_size
+ )
+ in_labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, iou_preds, low_res_masks = self.predictor._predict(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=self.multimask_output,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=points.repeat_interleave(masks.shape[1], dim=0),
+ low_res_masks=low_res_masks.flatten(0, 1),
+ )
+ del masks
+
+ if not self.use_m2m:
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate and filter by stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+ else:
+ # One step refinement using previous mask predictions
+ in_points = self.predictor._transforms.transform_coords(
+ data["points"], normalize=normalize, orig_hw=im_size
+ )
+ labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
+ masks, ious = self.refine_with_m2m(
+ in_points, labels, data["low_res_masks"], self.points_per_batch
+ )
+ data["masks"] = masks.squeeze(1)
+ data["iou_preds"] = ious.squeeze(1)
+
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+ )
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
+
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
+ new_masks = []
+ new_iou_preds = []
+
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
+ points_per_batch, points, point_labels, low_res_masks
+ ):
+ best_masks, best_iou_preds, _ = self.predictor._predict(
+ cur_points[:, None, :],
+ cur_point_labels[:, None],
+ mask_input=low_res_mask[:, None, :],
+ multimask_output=False,
+ return_logits=True,
+ )
+ new_masks.append(best_masks)
+ new_iou_preds.append(best_iou_preds)
+ masks = torch.cat(new_masks, dim=0)
+ return masks, torch.cat(new_iou_preds, dim=0)
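A minimal usage sketch for the automatic mask generator defined in this file. The module path and class name below (sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator), the checkpoint path, and the threshold values are assumptions, since the class header sits outside this hunk; adjust them to wherever the class actually lives.

    import numpy as np
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator  # assumed module/class name

    sam2_model = build_sam2(
        config_file="sam2_hiera_l.yaml",
        ckpt_path="./checkpoints/sam2_hiera_large.pt",  # assumed local checkpoint path
        device="cuda",
    )
    mask_generator = SAM2AutomaticMaskGenerator(
        sam2_model,
        points_per_side=32,            # 32x32 grid of point prompts per crop
        pred_iou_thresh=0.8,
        stability_score_thresh=0.95,
        output_mode="binary_mask",
    )

    image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for a real HWC uint8 RGB frame
    masks = mask_generator.generate(image)
    for m in masks[:3]:
        print(m["area"], m["bbox"], round(m["predicted_iou"], 3), round(m["stability_score"], 3))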
diff --git a/sam2/build_sam.py b/sam2/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab2d0922aa3d8f869afb20b27748d98304e175f0
--- /dev/null
+++ b/sam2/build_sam.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from hydra import compose
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+from .utils.misc import VARIANTS, variant_to_config_mapping
+
+
+def load_model(
+ variant: str,
+ ckpt_path=None,
+ device="cpu",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+) -> torch.nn.Module:
+ assert variant in VARIANTS, f"only accepted variants are {VARIANTS}"
+
+ return build_sam2(
+ config_file=variant_to_config_mapping[variant],
+ ckpt_path=ckpt_path,
+ device=device,
+ mode=mode,
+ hydra_overrides_extra=hydra_overrides_extra,
+ apply_postprocessing=apply_postprocessing,
+ )
+
+
+def build_sam2(
+ config_file,
+ ckpt_path=None,
+ device="cpu",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+):
+
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ ]
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
+
+
+def build_sam2_video_predictor(
+ config_file,
+ ckpt_path=None,
+ device="cpu",
+ mode="eval",
+ hydra_overrides_extra=[],
+ apply_postprocessing=True,
+):
+ hydra_overrides = [
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
+ ]
+ if apply_postprocessing:
+ hydra_overrides_extra = hydra_overrides_extra.copy()
+ hydra_overrides_extra += [
+ # dynamically fall back to multi-mask if the single mask is not stable
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
+ # "++model.fill_hole_area=8",
+ ]
+ hydra_overrides.extend(hydra_overrides_extra)
+
+ # Read config and init model
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
+ OmegaConf.resolve(cfg)
+ model = instantiate(cfg.model, _recursive_=True)
+ _load_checkpoint(model, ckpt_path)
+ model = model.to(device)
+ if mode == "eval":
+ model.eval()
+ return model
+
+
+def _load_checkpoint(model, ckpt_path):
+ if ckpt_path is not None:
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
+ if missing_keys:
+ logging.error(missing_keys)
+ raise RuntimeError("Missing keys while loading the SAM 2 checkpoint; see the log above.")
+ if unexpected_keys:
+ logging.error(unexpected_keys)
+ raise RuntimeError("Unexpected keys while loading the SAM 2 checkpoint; see the log above.")
+ logging.info("Loaded checkpoint sucessfully")
diff --git a/sam2/configs/__init__.py b/sam2/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/configs/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/configs/sam2_hiera_b+.yaml b/sam2/configs/sam2_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..58f3eb81554018e873f8515ecb98e36d16ac29e4
--- /dev/null
+++ b/sam2/configs/sam2_hiera_b+.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 112
+ num_heads: 2
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [896, 448, 224, 112]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/sam2/configs/sam2_hiera_l.yaml b/sam2/configs/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..918667f50c3e1ad2dcf77c0c14cb4dd114cfd080
--- /dev/null
+++ b/sam2/configs/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 144
+ num_heads: 2
+ stages: [2, 6, 36, 4]
+ global_att_blocks: [23, 33, 43]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ window_spec: [8, 4, 16, 8]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [1152, 576, 288, 144]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/sam2/configs/sam2_hiera_s.yaml b/sam2/configs/sam2_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2
--- /dev/null
+++ b/sam2/configs/sam2_hiera_s.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 11, 2]
+ global_att_blocks: [7, 10, 13]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ compile_image_encoder: False
diff --git a/sam2/configs/sam2_hiera_t.yaml b/sam2/configs/sam2_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a62c903aaa5f80828077c6e06a59626926570ed6
--- /dev/null
+++ b/sam2/configs/sam2_hiera_t.yaml
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+ _target_: sam2.modeling.sam2_base.SAM2Base
+ image_encoder:
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+ scalp: 1
+ trunk:
+ _target_: sam2.modeling.backbones.hieradet.Hiera
+ embed_dim: 96
+ num_heads: 1
+ stages: [1, 2, 7, 2]
+ global_att_blocks: [5, 7, 9]
+ window_pos_embed_bkg_spatial_size: [7, 7]
+ neck:
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 256
+ normalize: true
+ scale: null
+ temperature: 10000
+ d_model: 256
+ backbone_channel_list: [768, 384, 192, 96]
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
+ fpn_interp_model: nearest
+
+ memory_attention:
+ _target_: sam2.modeling.memory_attention.MemoryAttention
+ d_model: 256
+ pos_enc_at_input: true
+ layer:
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+ activation: relu
+ dim_feedforward: 2048
+ dropout: 0.1
+ pos_enc_at_attn: false
+ self_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ d_model: 256
+ pos_enc_at_cross_attn_keys: true
+ pos_enc_at_cross_attn_queries: false
+ cross_attention:
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
+ rope_theta: 10000.0
+ feat_sizes: [32, 32]
+ rope_k_repeat: True
+ embedding_dim: 256
+ num_heads: 1
+ downsample_rate: 1
+ dropout: 0.1
+ kv_in_dim: 64
+ num_layers: 4
+
+ memory_encoder:
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
+ out_dim: 64
+ position_encoding:
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+ num_pos_feats: 64
+ normalize: true
+ scale: null
+ temperature: 10000
+ mask_downsampler:
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
+ kernel_size: 3
+ stride: 2
+ padding: 1
+ fuser:
+ _target_: sam2.modeling.memory_encoder.Fuser
+ layer:
+ _target_: sam2.modeling.memory_encoder.CXBlock
+ dim: 256
+ kernel_size: 7
+ padding: 3
+ layer_scale_init_value: 1e-6
+ use_dwconv: True # depth-wise convs
+ num_layers: 2
+
+ num_maskmem: 7
+ image_size: 1024
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+ # SAM decoder
+ sigmoid_scale_for_mem_enc: 20.0
+ sigmoid_bias_for_mem_enc: -10.0
+ use_mask_input_as_output_without_sam: true
+ # Memory
+ directly_add_no_mem_embed: true
+ # use high-resolution feature map in the SAM mask decoder
+ use_high_res_features_in_sam: true
+ # output 3 masks on the first click on initial conditioning frames
+ multimask_output_in_sam: true
+ # SAM heads
+ iou_prediction_use_sigmoid: True
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder: true
+ add_tpos_enc_to_obj_ptrs: false
+ only_obj_ptrs_in_the_past_for_eval: true
+ # object occlusion prediction
+ pred_obj_scores: true
+ pred_obj_scores_mlp: true
+ fixed_no_obj_ptr: true
+ # multimask tracking settings
+ multimask_output_for_tracking: true
+ use_multimask_token_for_obj_ptr: true
+ multimask_min_pt_num: 0
+ multimask_max_pt_num: 1
+ use_mlp_for_obj_ptr_proj: true
+ # Compilation flag
+ # HieraT does not currently support compilation, should always be set to False
+ compile_image_encoder: False
diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b2b5255f11a31b524f22635c852e7deed04ae8b
Binary files /dev/null and b/sam2/modeling/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a9d8fca6e3c9fc260526640253d3d338c7ef2c1
Binary files /dev/null and b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc differ
diff --git a/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1aebed0f43913e0d424a78d3e46a3c05e8897f04
Binary files /dev/null and b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc differ
diff --git a/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6508fe8bb4ddfa05b4c689d39ce2de6404514ea3
Binary files /dev/null and b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc differ
diff --git a/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..380ff5fdbfc7a3ac1e45ccf8d1b26b8d00ce4a8f
Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc differ
diff --git a/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8efda8a2e63f373e68d87dba7f8733fa5c0f0d27
Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc differ
diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/backbones/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45f275aa7a5bb7b0a43e1801dd1e41567cca7e36
Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9952a40a7ebd152ca403871df2fa5bfe0ab120e4
Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc differ
diff --git a/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47478d488bd0b95c825e899e332481e5e8158d29
Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc differ
diff --git a/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe288f242b749edd87bba50c9d90b922b9040a1c
Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc differ
diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c6d3b9fc82ac13ff1bde50312eb1cce517eb776
--- /dev/null
+++ b/sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,294 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+from typing import List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.backbones.utils import (
+ PatchEmbed,
+ window_partition,
+ window_unpartition,
+)
+from sam2.modeling.sam2_utils import MLP, DropPath
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+
+ return x
+
+
+class MultiScaleAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ q_pool: nn.Module = None,
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.dim_out = dim_out
+
+ self.num_heads = num_heads
+ head_dim = dim_out // num_heads
+ self.scale = head_dim**-0.5
+
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ )
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+
+ x = self.proj(x)
+
+ return x
+
+
+class MultiScaleBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ ):
+ super().__init__()
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+
+ self.window_size = window_size
+
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
+ )
+
+ self.attn = MultiScaleAttention(
+ dim,
+ dim_out,
+ num_heads=num_heads,
+ q_pool=self.pool,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ dim_out,
+ int(dim_out * mlp_ratio),
+ dim_out,
+ num_layers=2,
+ activation=act_layer,
+ )
+
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ if self.q_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Hiera(nn.Module):
+ """
+ Reference: https://arxiv.org/abs/2306.00989
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ return_interm_layers=True, # return feats from every stage
+ ):
+ super().__init__()
+
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
+
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
+
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ )
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
+
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+ )
+ self.pos_embed_window = nn.Parameter(
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+
+ for i in range(depth):
+ dim_out = embed_dim
+ # lags by a block, so first block of
+ # next stage uses an initial window size
+ # of previous stage and final window size of current stage
+ window_size = self.window_spec[cur_stage - 1]
+
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size
+
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
+
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
+
+ embed_dim = dim_out
+ self.blocks.append(block)
+
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile(
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (
+ i in self.stage_ends and self.return_interm_layers
+ ):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
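A tiny smoke-test sketch of the Hiera trunk's multi-scale behaviour. The hyper-parameters below are illustrative only (far smaller than any shipped config); the point is that channels double and resolution halves at each stage boundary, and one feature map per stage is returned in high-to-low resolution order.

    import torch
    from sam2.modeling.backbones.hieradet import Hiera

    trunk = Hiera(
        embed_dim=32,
        num_heads=1,
        stages=(1, 1, 1, 1),        # one block per stage, just for the shape check
        window_spec=(8, 4, 4, 4),
        global_att_blocks=(),       # no global-attention blocks in this toy setup
    )
    x = torch.randn(1, 3, 256, 256)  # patch embed has stride 4 -> 64x64 tokens
    feats = trunk(x)
    print([tuple(f.shape) for f in feats])
    # [(1, 32, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8)]
    print(trunk.channel_list)        # [256, 128, 64, 32], lowest-resolution level first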
diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f92baf47dcab96385ff99899fd3e3a642c1cf9c
--- /dev/null
+++ b/sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+ def forward(self, sample: torch.Tensor):
+ # Forward through backbone
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+class FpnNeck(nn.Module):
+ """
+ A modified variant of Feature Pyramid Network (FPN) neck
+ (we remove output conv and also do bicubic interpolation similar to ViT
+ pos embed interpolation)
+ """
+
+ def __init__(
+ self,
+ position_encoding: nn.Module,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """Initialize the neck
+ :param trunk: the backbone
+ :param position_encoding: the positional encoding to use
+ :param d_model: the dimension of the model
+ :param neck_norm: the normalization to use
+ """
+ super().__init__()
+ self.position_encoding = position_encoding
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
+
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(
+ None if self.fpn_interp_model == "nearest" else False
+ ),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
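A minimal shape sketch of the neck: it projects every trunk level to d_model channels and, only for the levels listed in fpn_top_down_levels, adds the upsampled next-coarser level. The channel list mirrors the Hiera-L YAMLs above, but the resolutions are shrunk for the sketch and are otherwise arbitrary.

    import torch
    from sam2.modeling.backbones.image_encoder import FpnNeck
    from sam2.modeling.position_encoding import PositionEmbeddingSine

    neck = FpnNeck(
        position_encoding=PositionEmbeddingSine(num_pos_feats=256),
        d_model=256,
        backbone_channel_list=[1152, 576, 288, 144],  # deepest (lowest-res) level first
        fpn_top_down_levels=[2, 3],
        fpn_interp_model="nearest",
    )
    # Trunk outputs are ordered high-res -> low-res, matching Hiera's return order.
    xs = [
        torch.randn(1, 144, 64, 64),
        torch.randn(1, 288, 32, 32),
        torch.randn(1, 576, 16, 16),
        torch.randn(1, 1152, 8, 8),
    ]
    out, pos = neck(xs)
    print([tuple(o.shape) for o in out])  # every level now has 256 channels
    print([tuple(p.shape) for p in pos])  # sine position encodings, same shapes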
diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7
--- /dev/null
+++ b/sam2/modeling/backbones/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+ """
+ Window unpartition into original sequences and remove padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, ...] = (7, 7),
+ stride: Tuple[int, ...] = (4, 4),
+ padding: Tuple[int, ...] = (3, 3),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ):
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
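A quick round-trip sketch for the two windowing helpers above: window_partition pads inputs whose sides are not multiples of the window size, and window_unpartition strips that padding again, so the round trip is lossless.

    import torch
    from sam2.modeling.backbones.utils import window_partition, window_unpartition

    x = torch.randn(2, 50, 70, 32)                    # B, H, W, C; H and W are not multiples of 16
    windows, (Hp, Wp) = window_partition(x, window_size=16)
    print(tuple(windows.shape), (Hp, Wp))             # (40, 16, 16, 32) (64, 80)

    x_back = window_unpartition(windows, 16, (Hp, Wp), (50, 70))
    print(torch.equal(x, x_back))                     # True: padding is removed on the way back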
diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d7083200e2d3a1c378f5bec5239ebcd9036f18
--- /dev/null
+++ b/sam2/modeling/memory_attention.py
@@ -0,0 +1,168 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+from sam2.modeling.sam.transformer import RoPEAttention
+
+
+class MemoryAttentionLayer(nn.Module):
+
+ def __init__(
+ self,
+ activation: str,
+ cross_attention: nn.Module,
+ d_model: int,
+ dim_feedforward: int,
+ dropout: float,
+ pos_enc_at_attn: bool,
+ pos_enc_at_cross_attn_keys: bool,
+ pos_enc_at_cross_attn_queries: bool,
+ self_attention: nn.Module,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.dim_feedforward = dim_feedforward
+ self.dropout_value = dropout
+ self.self_attn = self_attention
+ self.cross_attn_image = cross_attention
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation_str = activation
+ self.activation = get_activation_fn(activation)
+
+ # Where to add pos enc
+ self.pos_enc_at_attn = pos_enc_at_attn
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+ def _forward_sa(self, tgt, query_pos):
+ # Self-Attention
+ tgt2 = self.norm1(tgt)
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+ tgt2 = self.self_attn(q, k, v=tgt2)
+ tgt = tgt + self.dropout1(tgt2)
+ return tgt
+
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+ kwds = {}
+ if num_k_exclude_rope > 0:
+ assert isinstance(self.cross_attn_image, RoPEAttention)
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+ # Cross-Attention
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.cross_attn_image(
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+ v=memory,
+ **kwds,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+
+ # Self-Attn, Cross-Attn
+ tgt = self._forward_sa(tgt, query_pos)
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+ # MLP
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+
+class MemoryAttention(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ pos_enc_at_input: bool,
+ layer: nn.Module,
+ num_layers: int,
+ batch_first: bool = True, # Do layers expect batch first input?
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.layers = get_clones(layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(d_model)
+ self.pos_enc_at_input = pos_enc_at_input
+ self.batch_first = batch_first
+
+ def forward(
+ self,
+ curr: torch.Tensor, # self-attention inputs
+ memory: torch.Tensor, # cross-attention inputs
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
+ ):
+ if isinstance(curr, list):
+ assert isinstance(curr_pos, list)
+ assert len(curr) == len(curr_pos) == 1
+ curr, curr_pos = (
+ curr[0],
+ curr_pos[0],
+ )
+
+ assert (
+ curr.shape[1] == memory.shape[1]
+ ), "Batch size must be the same for curr and memory"
+
+ output = curr
+ if self.pos_enc_at_input and curr_pos is not None:
+ output = output + 0.1 * curr_pos
+
+ if self.batch_first:
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+ memory = memory.transpose(0, 1)
+ memory_pos = memory_pos.transpose(0, 1)
+
+ for layer in self.layers:
+ kwds = {}
+ if isinstance(layer.cross_attn_image, RoPEAttention):
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+ output = layer(
+ tgt=output,
+ memory=memory,
+ pos=memory_pos,
+ query_pos=curr_pos,
+ **kwds,
+ )
+ normed_output = self.norm(output)
+
+ if self.batch_first:
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+
+ return normed_output
diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8b2df7aa754cd1a69e6231aad16ebaa10214411
--- /dev/null
+++ b/sam2/modeling/memory_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
+
+
+class MaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
+ """
+
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ super().__init__()
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+ def forward(self, x):
+ return self.encoder(x)
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+ use_dwconv=True,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, 4 * dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class Fuser(nn.Module):
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ super().__init__()
+ self.proj = nn.Identity()
+ self.layers = get_clones(layer, num_layers)
+
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ # normally x: (N, C, H, W)
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class MemoryEncoder(nn.Module):
+ def __init__(
+ self,
+ out_dim,
+ mask_downsampler,
+ fuser,
+ position_encoding,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ super().__init__()
+
+ self.mask_downsampler = mask_downsampler
+
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = fuser
+ self.position_encoding = position_encoding
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+ ) -> Dict[str, Any]:
+ ## Process masks
+ # apply sigmoid so there is less domain shift from the ground-truth masks, which are boolean
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
+
+ ## Fuse pix_feats and downsampled masks
+ # in case the visual features are on CPU, cast them to CUDA
+ pix_feat = pix_feat.to(masks.device)
+
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+
+ pos = self.position_encoding(x).to(x.dtype)
+
+ return {"vision_features": x, "vision_pos_enc": [pos]}
diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a308c30b69fd1a216fb1159b69261688a1792f6
--- /dev/null
+++ b/sam2/modeling/position_encoding.py
@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self,
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ ):
+ super().__init__()
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ self.cache = {}
+
+ def _encode_xy(self, x, y):
+ # The positions are expected to be normalized
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ pos_y = torch.stack(
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+ ).flatten(1)
+ return pos_x, pos_y
+
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ pos_x, pos_y = self._encode_xy(x, y)
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ return pos
+
+ encode = encode_boxes # Backwards compatibility
+
+ @torch.no_grad()
+ def encode_points(self, x, y, labels):
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ return pos
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+ cache_key = (x.shape[-2], x.shape[-1])
+ if cache_key in self.cache:
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ y_embed = (
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ .view(1, -1, 1)
+ .repeat(x.shape[0], 1, x.shape[-1])
+ )
+ x_embed = (
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ .view(1, 1, -1)
+ .repeat(x.shape[0], x.shape[-2], 1)
+ )
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
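+
+ # Usage sketch (illustrative note, not part of the upstream file): for a
+ # feature map `x` of shape (B, C, H, W), `forward(x)` returns a positional
+ # encoding of shape (B, 2 * self.num_pos_feats, H, W). Because the encoding
+ # depends only on (H, W), it is cached per spatial size and simply repeated
+ # along the batch dimension on subsequent calls.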
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
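+
+ # Usage sketch (illustrative note, not part of the upstream file):
+ #   pe_layer = PositionEmbeddingRandom(num_pos_feats=64)
+ #   grid_pe = pe_layer((32, 32))                       # -> (128, 32, 32)
+ #   pts = torch.rand(1, 3, 2) * 1024                   # 3 points in a 1024x1024 image
+ #   pt_pe = pe_layer.forward_with_coords(pts, (1024, 1024))  # -> (1, 3, 128)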
+
+
+# Rotary Positional Encoding, adapted from:
+# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# 2. https://github.com/naver-ai/rope-vit
+# 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+def init_t_xy(end_x: int, end_y: int):
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
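+
+
+# Shape sketch (illustrative note, not part of the upstream file): for a head
+# dimension of 64 on a 32x32 token grid,
+#   freqs_cis = compute_axial_cis(dim=64, end_x=32, end_y=32)
+# yields a complex tensor of shape (32 * 32, 32): one row per grid position,
+# with dim // 4 frequencies for the x axis and dim // 4 for the y axis.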
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+):
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = (
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ if xk.shape[-2] != 0
+ else None
+ )
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+ # no keys to rotate, due to dropout
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
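+
+
+# Usage sketch (illustrative note, not part of the upstream file): with queries
+# and keys of shape (B, n_heads, H * W, head_dim) coming from an (H, W) token
+# grid,
+#   freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)
+#   q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
+# returns rotated copies of q and k with unchanged shapes; the last dimension
+# is interpreted as head_dim // 2 complex pairs during the rotation.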
diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/sam/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b488a57b05dcd85df714e73b37524a2e8897c8
--- /dev/null
+++ b/sam2/modeling/sam/mask_decoder.py
@@ -0,0 +1,295 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.sam2_utils import MLP, LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+ ),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+ ),
+ activation(),
+ )
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+ )
+ self.conv_s1 = nn.Conv2d(
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+ )
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid_output=iou_prediction_use_sigmoid,
+ )
+ if self.pred_obj_scores:
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ torch.Tensor: batched SAM token for mask output
+ """
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+
+ # Prepare output
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
+ )
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1
+ )
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert (
+ image_pe.size(0) == 1
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+ )
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10) ≈ 1
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
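+
+ # Worked example (illustrative note, not part of the upstream file): with the
+ # default stability_delta of 0.05, area_i counts pixels whose logits exceed
+ # +0.05 and area_u those exceeding -0.05. A confident mask whose logits sit
+ # far from zero gives area_i close to area_u and a stability score near 1.0,
+ # whereas a mask with many near-zero logits (an uncertain boundary) scores lower.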
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(
+ multimask_iou_scores.size(0), device=all_iou_scores.device
+ )
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e177a71a32b4ca9811f7b7535de50d988b3ad492
--- /dev/null
+++ b/sam2/modeling/sam/prompt_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.position_encoding import PositionEmbeddingRandom
+from sam2.modeling.sam2_utils import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+ input_image_size (tuple(int, int)): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+ ]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (
+ 4 * image_embedding_size[0],
+ 4 * image_embedding_size[1],
+ )
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size
+ )
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(
+ coords, self.input_image_size
+ )
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim), device=self._get_device()
+ )
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
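+
+
+# Usage sketch (illustrative note, not part of the upstream file): encoding a
+# single positive click for a 1024x1024 input with a 64x64 image embedding:
+#   enc = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
+#                       input_image_size=(1024, 1024), mask_in_chans=16)
+#   coords = torch.tensor([[[512.0, 512.0]]])   # (B=1, N=1, 2)
+#   labels = torch.tensor([[1]])                # 1 = positive click
+#   sparse, dense = enc(points=(coords, labels), boxes=None, masks=None)
+#   # sparse: (1, 2, 256) -- the click plus one padding point
+#   # dense:  (1, 256, 64, 64) -- the learned no-mask embedding, broadcast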
diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..85baf152f418610eb0dcfa220f38dfc99ce53213
--- /dev/null
+++ b/sam2/modeling/sam/transformer.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import warnings
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+from sam2.utils.misc import get_sdp_backends
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLP(
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+ )
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ dropout: float = 0.0,
+ kv_in_dim: int = None,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert (
+ self.internal_dim % num_heads == 0
+ ), "num_heads must divide embedding_dim."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ self.dropout_p = dropout
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ dropout_p = self.dropout_p if self.training else 0.0
+ # Attention
+
+ #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
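+
+ # Shape sketch (illustrative note, not part of the upstream file): with
+ # embedding_dim=256, num_heads=8 and downsample_rate=2, inputs of shape
+ # (B, N, 256) are projected to an internal dim of 128, attention runs on
+ # (B, 8, N, 16) tensors, and the output is projected back to (B, N, 256).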
+
+
+class RoPEAttention(Attention):
+ """Attention with rotary position encoding."""
+
+ def __init__(
+ self,
+ *args,
+ rope_theta=10000.0,
+ # whether to repeat q rope to match k length
+ # this is needed for cross-attention to memories
+ rope_k_repeat=False,
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.compute_cis = partial(
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+ )
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = freqs_cis
+ self.rope_k_repeat = rope_k_repeat
+
+ def forward(
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+ ) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ )
+
+ dropout_p = self.dropout_p if self.training else 0.0
+
+ #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
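+
+ # Note (illustrative, not part of the upstream file): `num_k_exclude_rope`
+ # leaves the last `num_k_exclude_rope` key tokens un-rotated; in SAM 2 this
+ # keeps object-pointer tokens appended after the spatial memory keys from
+ # receiving the 2D rotary encoding meant for grid positions.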
diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4326f448de8c9a652f0a841dad9620579b7976f4
--- /dev/null
+++ b/sam2/modeling/sam2_base.py
@@ -0,0 +1,828 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+from torch.nn.init import trunc_normal_
+
+from sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames
+from sam2.modeling.sam.mask_decoder import MaskDecoder
+from sam2.modeling.sam.prompt_encoder import PromptEncoder
+from sam2.modeling.sam.transformer import TwoWayTransformer
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAM2Base(torch.nn.Module):
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ num_maskmem=7, # default 1 input frame + 6 previous frames
+ image_size=512,
+ backbone_stride=16, # stride of the image backbone output
+ sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
+ sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
+ # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
+ max_cond_frames_in_attn=-1,
+ # on the first frame, whether to directly add the no-memory embedding to the image feature
+ # (instead of using the transformer encoder)
+ directly_add_no_mem_embed=False,
+ # whether to use high-resolution feature maps in the SAM mask decoder
+ use_high_res_features_in_sam=False,
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
+ multimask_output_in_sam=False,
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
+ multimask_min_pt_num=1,
+ multimask_max_pt_num=1,
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
+ multimask_output_for_tracking=False,
+ # Whether to use multimask tokens for obj ptr; Only relevant when both
+ # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
+ use_multimask_token_for_obj_ptr: bool = False,
+ # whether to use sigmoid to restrict ious prediction to [0-1]
+ iou_prediction_use_sigmoid=False,
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
+ # (self.num_maskmem - 2) nearest frames from every r-th frame, plus the last frame.
+ memory_temporal_stride_for_eval=1,
+ # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+ # if `add_all_frames_to_correct_as_cond` is False, we keep the conditioning frame list restricted to the initial conditioning frames
+ add_all_frames_to_correct_as_cond=False,
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
+ non_overlap_masks_for_mem_enc=False,
+ # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+ use_obj_ptrs_in_encoder=False,
+ # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
+ max_obj_ptrs_in_encoder=16,
+ # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
+ add_tpos_enc_to_obj_ptrs=True,
+ # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
+ # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+ proj_tpos_enc_in_obj_ptrs=False,
+ # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
+ # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
+ only_obj_ptrs_in_the_past_for_eval=False,
+ # Whether to predict if there is an object in the frame
+ pred_obj_scores: bool = False,
+ # Whether to use an MLP to predict object scores
+ pred_obj_scores_mlp: bool = False,
+ # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
+ # Whether to have a fixed no obj pointer when there is no object present
+ # or to use it as an additive embedding with obj_ptr produced by decoder
+ fixed_no_obj_ptr: bool = False,
+ # Soft no object, i.e. mix in no_obj_ptr softly,
+ # hoping to make recovery easier if there is a mistake and to mitigate accumulation of errors
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ super().__init__()
+
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = memory_attention.d_model
+
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(
+ self.memory_encoder.out_proj, "weight"
+ ):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
+ )
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores = pred_obj_scores
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+
+ self._build_sam_heads()
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print(
+ "Image encoder compilation is enabled. First forward pass will be slow."
+ )
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
+ "See notebooks/video_predictor_example.ipynb for an example."
+ )
+
+ def _build_sam_heads(self):
+ """Build SAM-style prompt encoder and mask decoder."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+ # build PromptEncoder and MaskDecoder from SAM
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=self.sam_prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(
+ self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
+ )
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Forward SAM prompt encoders and mask heads.
+
+ Inputs:
+ - backbone_features: image features of [B, C, H, W] shape
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+ absolute pixel-unit coordinate in (x, y) format of the P input points
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+ positive clicks, 0 means negative clicks, and -1 means padding
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+ same spatial size as the image.
+ - high_res_features: either 1) None or 2) a list of length 2 containing
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+ which will be used as high-resolution feature maps for SAM decoder.
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
+ its corresponding IoU estimate.
+
+ Outputs:
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
+ the resolution (1/4 stride) of the input backbone_features.
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
+ upsampled from the low-resolution masks, with the same spatial size as the image
+ (stride is 1 pixel).
+ - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
+ if `multimask_output=False`), the estimated IoU of each output mask.
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+ based on the output token from the SAM mask decoder.
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provided, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ # Only hard possible with gt
+ assert not self.teacher_force_obj_scores_for_mem
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
+ (same input and output shapes as in _forward_sam_heads above).
+ """
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(
+ mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
+ )
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+ # In this method, we treat mask_input as the output, e.g. using it directly to create the spatial memory;
+ # below, we follow the same design and use mask_input to decide whether the object appears, instead of
+ # relying on the object scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ if self.pred_obj_scores:
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_masks,
+ high_res_masks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """Get the image feature on the input batch."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+ backbone_out["backbone_fpn"][0]
+ )
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+ backbone_out["backbone_fpn"][1]
+ )
+ return backbone_out
+
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepare and flatten visual features."""
+ backbone_out = backbone_out.copy()
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Fuse the current frame's visual feature map with previous memory."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat
+
+ num_obj_ptr_tokens = 0
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
+            # when getting the temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+            # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
+            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
+            # We also allow taking memory frames non-consecutively (with stride r>1), in which case
+            # we take (self.num_maskmem - 2) frames among every r-th frame, plus the last frame.
+ r = self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ if not track_in_reverse:
+ # the frame immediately before this frame (i.e. frame_idx - 1)
+ prev_frame_idx = frame_idx - t_rel
+ else:
+ # the frame immediately after this frame (i.e. frame_idx + 1)
+ prev_frame_idx = frame_idx + t_rel
+ else:
+ # for t_rel >= 2, we take the memory frame from every r-th frames
+ if not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
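+
+            # Worked example (illustrative only): with num_maskmem=7, r=2, forward
+            # tracking and frame_idx=10, the loop above picks prev_frame_idx
+            # 0, 2, 4, 6, 8 (every r-th frame) for t_rel >= 2, plus frame 9
+            # (the immediately preceding frame) for t_rel == 1.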
+
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
+ feats = prev["maskmem_features"].to(self.device)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(self.device)
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = (
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ )
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (abs(frame_idx - t), out["obj_ptr"])
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(
+ t, unselected_cond_outputs.get(t, None)
+ )
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                # If we have at least one object pointer, add them to the cross attention
+ if len(pos_and_ptrs) > 0:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(
+ -1, B, C // self.mem_dim, self.mem_dim
+ )
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ num_obj_ptr_tokens = 0
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+            # Use a dummy token on the first frame (to avoid an empty memory input to the transformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ is_mask_from_pts,
+ ):
+ """Encode the current image and its prediction into a memory feature."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
+ pred_masks_high_res
+ )
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
+ )
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+
+ return maskmem_features, maskmem_pos_enc
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(
+ pix_feat, high_res_features, mask_inputs
+ )
+ else:
+            # fuse the visual feature with previous memory features in the memory bank
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ _,
+ ) = sam_outputs
+
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+
+ return current_out
+
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+ multimask_output = (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _apply_non_overlapping_constraints(self, pred_masks):
+ """
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
+ keep only the highest scoring object at each spatial location in pred_masks.
+ """
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9705963efc57d74b7d1bff31692d7d293a46ad
--- /dev/null
+++ b/sam2/modeling/sam2_utils.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
+ - a) the closest conditioning frame before `frame_idx` (if any);
+ - b) the closest conditioning frame after `frame_idx` (if any);
+ - c) any other temporally closest conditioning frames until reaching a total
+ of `max_cond_frame_num` conditioning frames.
+
+ Outputs:
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+ selected_outputs = {}
+
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
+ }
+
+ return selected_outputs, unselected_outputs
+
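+# Usage sketch (toy values, for illustration; out0/out5/out20 stand for per-frame output dicts):
+#   cond = {0: out0, 5: out5, 20: out20}
+#   selected, unselected = select_closest_cond_frames(7, cond, max_cond_frame_num=2)
+#   gives selected == {5: out5, 20: out20} and unselected == {0: out0}.
+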
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
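+# Shape sketch (illustrative): get_1d_sine_pe(torch.arange(4), dim=256) returns a
+# (4, 256) tensor whose first half is sin features and second half is cos features.
+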
+
+def get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+class DropPath(nn.Module):
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
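+
+# Behaviour note (illustrative): with drop_prob=0.2 in training mode, each sample in
+# the batch is zeroed with probability 0.2 and the surviving samples are rescaled by
+# 1/0.8 (when scale_by_keep=True), keeping the expected activation magnitude unchanged.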
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: nn.Module = nn.ReLU,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+ self.act = activation()
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
+
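+# Shape sketch (illustrative): MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
+# stacks Linear(256->256), Linear(256->256), Linear(256->4), applying the activation after
+# every layer except the last.
+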
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
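+
+# Note (illustrative): unlike nn.LayerNorm over the last dimension, LayerNorm2d
+# normalizes across the channel dimension of an NCHW tensor, e.g.
+# LayerNorm2d(64)(torch.randn(2, 64, 32, 32)) returns the same (2, 64, 32, 32) shape.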
diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6444b0f225d7a99774e51e587cd9fa7d9f71ec39
--- /dev/null
+++ b/sam2/sam2_image_predictor.py
@@ -0,0 +1,445 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.utils.transforms import SAM2Transforms
+
+
+class SAM2ImagePredictor:
+ def __init__(
+ self,
+ sam_model: SAM2Base,
+ mask_threshold=0.0,
+ max_hole_area=0.0,
+ max_sprinkle_area=0.0,
+ ) -> None:
+ """
+ Uses SAM-2 to calculate the image embedding for an image, and then
+ allow repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam-2): The model to use for mask prediction.
+ mask_threshold (float): The threshold to use when converting mask logits
+ to binary masks. Masks are thresholded at 0 by default.
+            max_hole_area (float): If max_hole_area > 0, we fill small holes of area up to
+                max_hole_area in the low-res masks.
+            max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small isolated
+                regions of area up to max_sprinkle_area in the low-res masks.
+ """
+ super().__init__()
+ self.model = sam_model
+ self._transforms = SAM2Transforms(
+ resolution=self.model.image_size,
+ mask_threshold=mask_threshold,
+ max_hole_area=max_hole_area,
+ max_sprinkle_area=max_sprinkle_area,
+ )
+
+ # Predictor state
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ # Whether the predictor is set for single image or a batch of images
+ self._is_batch = False
+
+ # Predictor config
+ self.mask_threshold = mask_threshold
+
+ # Spatial dim for backbone feature maps
+ self._bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+
+ @torch.no_grad()
+ def set_image(
+ self,
+ image: Union[np.ndarray, Image],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+          image (np.ndarray or PIL Image): The input image to embed, in RGB format. It should
+            be in HWC format if an np.ndarray, or in (width, height) form if a PIL Image,
+            with pixel values in [0, 255].
+ """
+ self.reset_predictor()
+ # Transform the image to the form expected by the model
+ if isinstance(image, np.ndarray):
+ logging.info("For numpy array image, we assume (HxWxC) format")
+ self._orig_hw = [image.shape[:2]]
+ elif isinstance(image, Image):
+ w, h = image.size
+ self._orig_hw = [(h, w)]
+ else:
+ raise NotImplementedError("Image format not supported")
+
+ input_image = self._transforms(image)
+ input_image = input_image[None, ...].to(self.device)
+
+ assert (
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+ logging.info("Computing image embeddings for the provided image...")
+ backbone_out = self.model.forward_image(input_image)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ print("feat_size", self._bb_feat_sizes[::-1])
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ logging.info("Image embeddings computed.")
+
+ @torch.no_grad()
+ def set_image_batch(
+ self,
+        image_list: List[np.ndarray],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image batch, allowing
+ masks to be predicted with the 'predict_batch' method.
+
+ Arguments:
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
+ with pixel values in [0, 255].
+ """
+ self.reset_predictor()
+ assert isinstance(image_list, list)
+ self._orig_hw = []
+ for image in image_list:
+ assert isinstance(
+ image, np.ndarray
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
+ self._orig_hw.append(image.shape[:2])
+ # Transform the image to the form expected by the model
+ img_batch = self._transforms.forward_batch(image_list)
+ img_batch = img_batch.to(self.device)
+ batch_size = img_batch.shape[0]
+ assert (
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+ logging.info("Computing image embeddings for the provided images...")
+ backbone_out = self.model.forward_image(img_batch)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+ feats = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+ self._is_image_set = True
+ self._is_batch = True
+ logging.info("Image embeddings computed.")
+
+ def predict_batch(
+ self,
+ point_coords_batch: List[np.ndarray] = None,
+ point_labels_batch: List[np.ndarray] = None,
+ box_batch: List[np.ndarray] = None,
+ mask_input_batch: List[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
+ It returns a tupele of lists of masks, ious, and low_res_masks_logits.
+ """
+ assert self._is_batch, "This function should only be used when in batched mode"
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image_batch(...) before mask prediction."
+ )
+ num_images = len(self._features["image_embed"])
+ all_masks = []
+ all_ious = []
+ all_low_res_masks = []
+ for img_idx in range(num_images):
+ # Transform input prompts
+ point_coords = (
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
+ )
+ point_labels = (
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
+ )
+ box = box_batch[img_idx] if box_batch is not None else None
+ mask_input = (
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
+ )
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords,
+ point_labels,
+ box,
+ mask_input,
+ normalize_coords,
+ img_idx=img_idx,
+ )
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ img_idx=img_idx,
+ )
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = (
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ )
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ all_masks.append(masks_np)
+ all_ious.append(iou_predictions_np)
+ all_low_res_masks.append(low_res_masks_np)
+
+ return all_masks, all_ious, all_low_res_masks
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ normalize_coords=True,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+          box (np.ndarray or None): A length 4 array giving a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+          normalize_coords (bool): If true, the point coordinates are assumed to be given with respect to the original image dimensions and will be normalized to the range [0,1].
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ # Transform input prompts
+
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+ point_coords, point_labels, box, mask_input, normalize_coords
+ )
+
+ masks, iou_predictions, low_res_masks = self._predict(
+ unnorm_coords,
+ labels,
+ unnorm_box,
+ mask_input,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+ return masks_np, iou_predictions_np, low_res_masks_np
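+
+    # Usage sketch (illustrative; `sam_model` and the click location are assumptions):
+    #   predictor = SAM2ImagePredictor(sam_model)
+    #   predictor.set_image(image)                # HWC RGB uint8 np.ndarray or PIL image
+    #   masks, ious, low_res = predictor.predict(
+    #       point_coords=np.array([[512, 384]]),  # one (X, Y) click in pixels
+    #       point_labels=np.array([1]),           # 1 = foreground
+    #       multimask_output=True,
+    #   )
+    #   best_mask = masks[np.argmax(ious)]        # pick the highest-scoring mask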
+
+ def _prep_prompts(
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+ ):
+
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = torch.as_tensor(
+ point_coords, dtype=torch.float, device=self.device
+ )
+ unnorm_coords = self._transforms.transform_coords(
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ )
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ if len(unnorm_coords.shape) == 2:
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+ if box is not None:
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ unnorm_box = self._transforms.transform_boxes(
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+ ) # Bx2x2
+ if mask_logits is not None:
+ mask_input = torch.as_tensor(
+ mask_logits, dtype=torch.float, device=self.device
+ )
+ if len(mask_input.shape) == 3:
+ mask_input = mask_input[None, :, :, :]
+ return mask_input, unnorm_coords, labels, unnorm_box
+
+ @torch.no_grad()
+ def _predict(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ img_idx: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using SAM2Transforms.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (torch.Tensor): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
+
+ if point_coords is not None:
+ concat_points = (point_coords, point_labels)
+ else:
+ concat_points = None
+
+ # Embed prompts
+ if boxes is not None:
+ box_coords = boxes.reshape(-1, 2, 2)
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+ box_labels = box_labels.repeat(boxes.size(0), 1)
+ # we merge "boxes" and "points" into a single "concat_points" input (where
+ # boxes are added at the beginning) to sam_prompt_encoder
+ if concat_points is not None:
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+ concat_points = (concat_coords, concat_labels)
+ else:
+ concat_points = (box_coords, box_labels)
+
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=concat_points,
+ boxes=None,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ batched_mode = (
+ concat_points is not None and concat_points[0].shape[0] > 1
+ ) # multi object prediction
+ high_res_features = [
+ feat_level[img_idx].unsqueeze(0)
+ for feat_level in self._features["high_res_feats"]
+ ]
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self._transforms.postprocess_masks(
+ low_res_masks, self._orig_hw[img_idx]
+ )
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+ if not return_logits:
+ masks = masks > self.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self._is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert (
+ self._features is not None
+ ), "Features must exist if an image has been set."
+ return self._features["image_embed"]
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_predictor(self) -> None:
+ """
+ Resets the image embeddings and other state variables.
+ """
+ self._is_image_set = False
+ self._features = None
+ self._orig_hw = None
+ self._is_batch = False
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1b911105a395fd759a0bdef5ef4574f0e7bf301
--- /dev/null
+++ b/sam2/sam2_video_predictor.py
@@ -0,0 +1,900 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import torch
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+ """The predictor class to handle user interactions and manage inference states."""
+
+ def __init__(
+ self,
+ fill_hole_area=0,
+ # whether to apply non-overlapping constraints on the output object masks
+ non_overlap_masks=False,
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
+ clear_non_cond_mem_around_input=False,
+        # whether to also clear non-conditioning memory of the surrounding frames for multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True)
+ clear_non_cond_mem_for_multi_obj=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.fill_hole_area = fill_hole_area
+ self.non_overlap_masks = non_overlap_masks
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+
+ @torch.inference_mode()
+ def init_state(
+ self,
+ video_path=None,
+ images=None,
+ device="cpu",
+ async_loading_frames=False,
+ ):
+ """Initialize a inference state."""
+ if images is not None:
+ images, video_height, video_width = load_video_frames(
+ video_path=None,
+ images=images,
+ image_size=self.image_size,
+ async_loading_frames=async_loading_frames,
+ device=device,
+ )
+ else:
+ images, video_height, video_width = load_video_frames(
+ video_path=video_path,
+ image_size=self.image_size,
+ async_loading_frames=async_loading_frames,
+ device=device,
+ )
+ inference_state = dict()
+ inference_state["images"] = images
+ inference_state["num_frames"] = len(images)
+ # the original video height and width, used for resizing final output scores
+ inference_state["video_height"] = video_height
+ inference_state["video_width"] = video_width
+ inference_state["device"] = device
+ inference_state["storage_device"] = device
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # visual features on a small number of recently visited frames for quick interactions
+ inference_state["cached_features"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # A storage to hold the model's tracking results and states on each frame
+ inference_state["output_dict"] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+ # A temporary storage to hold new outputs when user interact with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+ # Frames that already holds consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ inference_state["consolidated_frame_inds"] = {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ }
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"] = {}
+ # Warm up the visual backbone and cache the image feature on frame 0
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+ return inference_state
+
+ def _obj_id_to_idx(self, inference_state, obj_id):
+ """Map client-side object id to model-side object index."""
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ # This is a new object id not sent to the server before. We only allow adding
+ # new objects *before* the tracking starts.
+ allow_new_object = not inference_state["tracking_has_started"]
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(inference_state["obj_id_to_idx"])
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ inference_state["output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ return obj_idx
+ else:
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ )
+
+ def _obj_idx_to_id(self, inference_state, obj_idx):
+ """Map model-side object index to client-side object id."""
+ return inference_state["obj_idx_to_id"][obj_idx]
+
+ def _get_obj_num(self, inference_state):
+ """Get the total number of unique object ids received so far in this session."""
+ return len(inference_state["obj_idx_to_id"])
+
+ @torch.inference_mode()
+ def add_new_points(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ points,
+ labels,
+ clear_old_points=True,
+ normalize_coords=True,
+ ):
+ """Add new points to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if not isinstance(points, torch.Tensor):
+ points = torch.tensor(points, dtype=torch.float32)
+ if not isinstance(labels, torch.Tensor):
+ labels = torch.tensor(labels, dtype=torch.int32)
+ if points.dim() == 2:
+ points = points.unsqueeze(0) # add batch dimension
+ if labels.dim() == 1:
+ labels = labels.unsqueeze(0) # add batch dimension
+ if normalize_coords:
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ points = points / torch.tensor([video_W, video_H]).to(points.device)
+ # scale the (normalized) coordinates by the model's internal image size
+ points = points * self.image_size
+ points = points.to(inference_state["device"])
+ labels = labels.to(inference_state["device"])
+
+ if not clear_old_points:
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
+ else:
+ point_inputs = None
+ point_inputs = concat_points(point_inputs, points, labels)
+
+ point_inputs_per_frame[frame_idx] = point_inputs
+ mask_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input points are used to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ if prev_out is None:
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+ if prev_out is not None and prev_out["pred_masks"] is not None:
+ prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"])
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=None,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
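+
+    # Usage sketch (illustrative; the video path, object id and click are assumptions):
+    #   state = predictor.init_state(video_path="clip.mp4")
+    #   frame_idx, obj_ids, masks = predictor.add_new_points(
+    #       state, frame_idx=0, obj_id=1,
+    #       points=[[320.0, 240.0]], labels=[1],  # one positive click in pixel coords
+    #   )
+    #   # masks: (num_objects, 1, video_H, video_W) mask logits at the original resolution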
+
+ @torch.inference_mode()
+ def add_new_mask(
+ self,
+ inference_state,
+ frame_idx,
+ obj_id,
+ mask,
+ ):
+ """Add new mask to a frame."""
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, dtype=torch.bool)
+ assert mask.dim() == 2
+ mask_H, mask_W = mask.shape
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+ # resize the mask if it doesn't match the model's image size
+ if mask_H != self.image_size or mask_W != self.image_size:
+ mask_inputs = torch.nn.functional.interpolate(
+ mask_inputs_orig,
+ size=(self.image_size, self.image_size),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ mask_inputs = (mask_inputs >= 0.5).float()
+ else:
+ mask_inputs = mask_inputs_orig
+
+ mask_inputs_per_frame[frame_idx] = mask_inputs
+ point_inputs_per_frame.pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input mask is used to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+ # whether to track in reverse time order
+ if is_init_cond_frame:
+ reverse = False
+ else:
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ current_out, _ = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ obj_ids = inference_state["obj_ids"]
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state,
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ consolidate_at_video_res=True,
+ )
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, consolidated_out["pred_masks_video_res"]
+ )
+ return frame_idx, obj_ids, video_res_masks
+
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
+ """
+ Resize the object scores to the original video resolution (video_res_masks)
+ and apply non-overlapping constraints for final output.
+ """
+ device = inference_state["device"]
+ video_H = inference_state["video_height"]
+ video_W = inference_state["video_width"]
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
+ if any_res_masks.shape[-2:] == (video_H, video_W):
+ video_res_masks = any_res_masks
+ else:
+ video_res_masks = torch.nn.functional.interpolate(
+ any_res_masks,
+ size=(video_H, video_W),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks:
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+ return any_res_masks, video_res_masks
+
+ def _consolidate_temp_output_across_obj(
+ self,
+ inference_state,
+ frame_idx,
+ is_cond,
+ run_mem_encoder,
+ consolidate_at_video_res=False,
+ ):
+ """
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+ a frame into a single output for all objects, including
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
+ (if they don't exist in `output_dict_per_obj` for this frame);
+        2) if specified, rerun the memory encoder after applying non-overlapping constraints
+ on the object scores.
+ """
+ batch_size = self._get_obj_num(inference_state)
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Optionally, we allow consolidating the temporary outputs at the original
+ # video resolution (to provide a better editing experience for mask prompts).
+ if consolidate_at_video_res:
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
+ consolidated_H = inference_state["video_height"]
+ consolidated_W = inference_state["video_width"]
+ consolidated_mask_key = "pred_masks_video_res"
+ else:
+ consolidated_H = consolidated_W = self.image_size // 4
+ consolidated_mask_key = "pred_masks"
+
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ consolidated_mask_key: torch.full(
+ size=(batch_size, 1, consolidated_H, consolidated_W),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["storage_device"],
+ ),
+ "obj_ptr": torch.full(
+ size=(batch_size, self.hidden_dim),
+ fill_value=NO_OBJ_SCORE,
+ dtype=torch.float32,
+ device=inference_state["device"],
+ ),
+ }
+ empty_mask_ptr = None
+ for obj_idx in range(batch_size):
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ if out is None:
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+ if out is None:
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ if out is None:
+ # Fill in dummy object pointers for those objects without any inputs or
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+ # i.e. when we need to build the memory for tracking).
+ if run_mem_encoder:
+ if empty_mask_ptr is None:
+ empty_mask_ptr = self._get_empty_mask_ptr(
+ inference_state, frame_idx
+ )
+ # fill object pointer with a dummy pointer (based on an empty mask)
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
+ continue
+ # Add the temporary object output mask to consolidated output mask
+ obj_mask = out["pred_masks"]
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+ else:
+ # Resize first if temporary object mask has a different resolution
+ resized_obj_mask = torch.nn.functional.interpolate(
+ obj_mask,
+ size=consolidated_pred_masks.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+ # Optionally, apply non-overlapping constraints on the consolidated scores
+ # and rerun the memory encoder
+ if run_mem_encoder:
+ device = inference_state["device"]
+ high_res_masks = torch.nn.functional.interpolate(
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.non_overlap_masks_for_mem_enc:
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+ inference_state=inference_state,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ high_res_masks=high_res_masks,
+ is_mask_from_pts=True, # these frames are what the user interacted with
+ )
+ consolidated_out["maskmem_features"] = maskmem_features
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+
+ return consolidated_out
+
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
+ """Get a dummy object pointer based on an empty mask on the current frame."""
+ # A dummy (empty) mask with a single object
+ batch_size = 1
+ mask_inputs = torch.zeros(
+ (batch_size, 1, self.image_size, self.image_size),
+ dtype=torch.float32,
+ device=inference_state["device"],
+ )
+
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # Feed the empty mask and image feature above to get a dummy object pointer
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=True,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=None,
+ mask_inputs=mask_inputs,
+ output_dict={},
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=False,
+ run_mem_encoder=False,
+ prev_sam_mask_logits=None,
+ )
+ return current_out["obj_ptr"]
+
+ @torch.inference_mode()
+ def propagate_in_video_preflight(self, inference_state):
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
+ # Tracking has started and we don't allow adding new objects until session is reset.
+ inference_state["tracking_has_started"] = True
+ batch_size = self._get_obj_num(inference_state)
+
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+ output_dict = inference_state["output_dict"]
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
+ # temporary outputs have been added (either in this call or any previous calls
+ # to `propagate_in_video_preflight`).
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ for is_cond in [False, True]:
+ # Separately consolidate conditioning and non-conditioning temp outputs
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks for mask inputs
+ # via `add_new_points` or `add_new_mask`)
+ temp_frame_inds = set()
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
+ # consolidate the temporary outputs across all objects on this frame
+ for frame_idx in temp_frame_inds:
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+ )
+ # merge them into "output_dict" and also create per-object slices
+ output_dict[storage_key][frame_idx] = consolidated_out
+ self._add_output_per_object(
+ inference_state, frame_idx, consolidated_out, storage_key
+ )
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ obj_temp_output_dict[storage_key].clear()
+
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in output_dict["cond_frame_outputs"]:
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ assert frame_idx in output_dict["cond_frame_outputs"]
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+ # with either points or mask inputs (which should be true under a correct workflow).
+ all_consolidated_frame_inds = (
+ consolidated_frame_inds["cond_frame_outputs"]
+ | consolidated_frame_inds["non_cond_frame_outputs"]
+ )
+ input_frames_inds = set()
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+ input_frames_inds.update(point_inputs_per_frame.keys())
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+ input_frames_inds.update(mask_inputs_per_frame.keys())
+ assert all_consolidated_frame_inds == input_frames_inds
+
+ @torch.inference_mode()
+ def propagate_in_video(
+ self,
+ inference_state,
+ start_frame_idx=None,
+ max_frame_num_to_track=None,
+ reverse=False,
+ ):
+ """Propagate the input points across frames to track in the entire video."""
+ self.propagate_in_video_preflight(inference_state)
+
+ output_dict = inference_state["output_dict"]
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+ obj_ids = inference_state["obj_ids"]
+ num_frames = inference_state["num_frames"]
+ batch_size = self._get_obj_num(inference_state)
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+ )
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ start_frame_idx = min(output_dict["cond_frame_outputs"])
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
+ )
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ # We skip those frames already in consolidated outputs (these are frames
+ # that received input clicks or mask). Note that we cannot directly run
+ # batched forward on them via `_run_single_frame_inference` because the
+ # number of clicks on each object might be different.
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ if clear_non_cond_mem:
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = output_dict[storage_key][frame_idx]
+ pred_masks = current_out["pred_masks"]
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out, pred_masks = self._run_single_frame_inference(
+ inference_state=inference_state,
+ output_dict=output_dict,
+ frame_idx=frame_idx,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=reverse,
+ run_mem_encoder=True,
+ )
+ output_dict[storage_key][frame_idx] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(
+ inference_state, frame_idx, current_out, storage_key
+ )
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+
+ # Resize the output mask to the original video resolution (we directly use
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
+ _, video_res_masks = self._get_orig_video_res_output(
+ inference_state, pred_masks
+ )
+ yield frame_idx, obj_ids, video_res_masks
+
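+ # Illustrative usage sketch (not part of the predictor API itself): after prompts
+ # have been added on some frame, a caller typically iterates this generator to
+ # collect per-frame masks. `predictor`, `inference_state`, and the 0.0 threshold
+ # below are assumptions for illustration only.
+ #
+ #   video_segments = {}
+ #   for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(inference_state):
+ #       video_segments[frame_idx] = {
+ #           obj_id: (video_res_masks[i] > 0.0).cpu().numpy()
+ #           for i, obj_id in enumerate(obj_ids)
+ #       }
+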
+ def _add_output_per_object(
+ self, inference_state, frame_idx, current_out, storage_key
+ ):
+ """
+ Split a multi-object output into per-object output slices and add them into
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
+ """
+ maskmem_features = current_out["maskmem_features"]
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
+ obj_slice = slice(obj_idx, obj_idx + 1)
+ obj_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": current_out["pred_masks"][obj_slice],
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
+ }
+ if maskmem_features is not None:
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
+ if maskmem_pos_enc is not None:
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+ obj_output_dict[storage_key][frame_idx] = obj_out
+
+ @torch.inference_mode()
+ def reset_state(self, inference_state):
+ """Remove all input points or mask in all frames throughout the video."""
+ self._reset_tracking_results(inference_state)
+ # Remove all object ids
+ inference_state["obj_id_to_idx"].clear()
+ inference_state["obj_idx_to_id"].clear()
+ inference_state["obj_ids"].clear()
+ inference_state["point_inputs_per_obj"].clear()
+ inference_state["mask_inputs_per_obj"].clear()
+ inference_state["output_dict_per_obj"].clear()
+ inference_state["temp_output_dict_per_obj"].clear()
+
+ def _reset_tracking_results(self, inference_state):
+ """Reset all tracking inputs and results across the videos."""
+ for v in inference_state["point_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["mask_inputs_per_obj"].values():
+ v.clear()
+ for v in inference_state["output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ for v in inference_state["temp_output_dict_per_obj"].values():
+ v["cond_frame_outputs"].clear()
+ v["non_cond_frame_outputs"].clear()
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"].clear()
+
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
+ """Compute the image features on a given frame."""
+ # Look up in the cache first
+ image, backbone_out = inference_state["cached_features"].get(
+ frame_idx, (None, None)
+ )
+ if backbone_out is None:
+ # Cache miss -- we will run inference on a single image
+ image = (
+ inference_state["images"][frame_idx]
+ .to(inference_state["device"])
+ .float()
+ .unsqueeze(0)
+ )
+ backbone_out = self.forward_image(image)
+ # Cache the most recent frame's feature (for repeated interactions with
+ # a frame; we can use an LRU cache for more frames in the future).
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+ # expand the features to have the same dimension as the number of objects
+ expanded_image = image.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out = {
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+ }
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+ expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+ batch_size, -1, -1, -1
+ )
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch_size, -1, -1, -1)
+ expanded_backbone_out["vision_pos_enc"][i] = pos
+
+ features = self._prepare_backbone_features(expanded_backbone_out)
+ features = (expanded_image,) + features
+ return features
+
+ def _run_single_frame_inference(
+ self,
+ inference_state,
+ output_dict,
+ frame_idx,
+ batch_size,
+ is_init_cond_frame,
+ point_inputs,
+ mask_inputs,
+ reverse,
+ run_mem_encoder,
+ prev_sam_mask_logits=None,
+ ):
+ """Run tracking on a single frame based on current inputs and previous memory."""
+ # Retrieve correct image features
+ (
+ _,
+ _,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ output_dict=output_dict,
+ num_frames=inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ pred_masks_gpu = current_out["pred_masks"]
+ # potentially fill holes in the predicted masks
+ if self.fill_hole_area > 0:
+ pred_masks_gpu = fill_holes_in_mask_scores(
+ pred_masks_gpu, self.fill_hole_area
+ )
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+ obj_ptr = current_out["obj_ptr"]
+ # make a compact version of this frame's output to reduce the state size
+ compact_current_out = {
+ "maskmem_features": maskmem_features,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ "pred_masks": pred_masks,
+ "obj_ptr": obj_ptr,
+ }
+ return compact_current_out, pred_masks_gpu
+
+ def _run_memory_encoder(
+ self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
+ ):
+ """
+ Run the memory encoder on `high_res_masks`. This is usually after applying
+ non-overlapping constraints to the object scores. Since their scores have changed,
+ their memory also needs to be recomputed with the memory encoder.
+ """
+ # Retrieve correct image features
+ _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+ inference_state, frame_idx, batch_size
+ )
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ is_mask_from_pts=is_mask_from_pts,
+ )
+
+ # optionally offload the output to CPU memory to save GPU space
+ storage_device = inference_state["storage_device"]
+ maskmem_features = maskmem_features.to(torch.bfloat16)
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+ )
+ return maskmem_features, maskmem_pos_enc
+
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
+ """
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+ a constant in the inference session to reduce session storage size.
+ """
+ model_constants = inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's the same across objects
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ expanded_maskmem_pos_enc = [
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+ ]
+ else:
+ expanded_maskmem_pos_enc = None
+ return expanded_maskmem_pos_enc
+
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+ """
+ Remove the non-conditioning memory around the input frame. When users provide
+ correction clicks, the surrounding frames' non-conditioning memories can still
+ contain outdated object appearance information and could confuse the model.
+
+ This method clears those non-conditioning memories surrounding the interacted
+ frame to avoid giving the model both old and new information about the object.
+ """
+ r = self.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.num_maskmem
+ frame_idx_end = frame_idx + r * self.num_maskmem
+ output_dict = inference_state["output_dict"]
+ non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ non_cond_frame_outputs.pop(t, None)
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
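+ # Worked example (the values are illustrative, not asserted defaults): with
+ # memory_temporal_stride_for_eval r = 1, num_maskmem = 7 and an interacted
+ # frame_idx = 30, the loop above drops the non-conditioning memories of frames
+ # 23..37 (inclusive), both in the joint "output_dict" and in every per-object
+ # output dict.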
diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d9f56d9ac50c1c34d742f757c488b06b0c2b4e
--- /dev/null
+++ b/sam2/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .download import download_weights
diff --git a/sam2/utils/__pycache__/__init__.cpython-311.pyc b/sam2/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f420b43150fff84bc29294a177347c686886ae8e
Binary files /dev/null and b/sam2/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sam2/utils/__pycache__/download.cpython-311.pyc b/sam2/utils/__pycache__/download.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c347a06cb6c1b0f14fded33a11ffdef72738f50
Binary files /dev/null and b/sam2/utils/__pycache__/download.cpython-311.pyc differ
diff --git a/sam2/utils/__pycache__/misc.cpython-311.pyc b/sam2/utils/__pycache__/misc.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65e5d90f57b978985524fa52926f6151f4550ff4
Binary files /dev/null and b/sam2/utils/__pycache__/misc.cpython-311.pyc differ
diff --git a/sam2/utils/__pycache__/transforms.cpython-311.pyc b/sam2/utils/__pycache__/transforms.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ffee0d81bb9a8bcc616c68882c205364d2c5b0c
Binary files /dev/null and b/sam2/utils/__pycache__/transforms.cpython-311.pyc differ
diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e
--- /dev/null
+++ b/sam2/utils/amg.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.float().detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
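+# Worked example for the RLE helpers above (illustrative): the 2x2 mask
+#   [[1, 0],
+#    [1, 0]]
+# is flattened in Fortran (column-major) order to [1, 1, 0, 0], so its uncompressed RLE
+# is {"size": [2, 2], "counts": [0, 2, 2]} -- the leading 0 marks that the first run is
+# foreground, followed by a run of two 1s and a run of two 0s. `area_from_rle` sums the
+# odd-indexed counts (area 2), and `rle_to_mask` reconstructs the original boolean mask.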
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
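+# Worked example (the threshold values are illustrative): with mask_threshold = 0.0 and
+# threshold_offset = 1.0, the score is |logits > 1.0| / |logits > -1.0|, i.e. the IoU of
+# the "tight" and "loose" binarizations of the same predicted mask; a score near 1.0
+# means the mask barely changes as the threshold shifts, indicating a stable prediction.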
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+ # Crops in XYWH format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of whether the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/sam2/utils/download.py b/sam2/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..577a0b16e6bba6a70dd94cc9360bc21ce5fe4586
--- /dev/null
+++ b/sam2/utils/download.py
@@ -0,0 +1,36 @@
+import os
+import subprocess
+from typing import List
+
+
+def download_weights(output_directory: str = "artifacts") -> None:
+ base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
+ file_names: List[str] = [
+ "sam2_hiera_tiny.pt",
+ "sam2_hiera_small.pt",
+ "sam2_hiera_base_plus.pt",
+ "sam2_hiera_large.pt",
+ ]
+
+ if not os.path.exists(output_directory):
+ os.makedirs(output_directory)
+
+ for file_name in file_names:
+ file_path = os.path.join(output_directory, file_name)
+ if not os.path.exists(file_path):
+ url = f"{base_url}{file_name}"
+ command = ["wget", url, "-P", output_directory]
+ try:
+ result = subprocess.run(
+ command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ print(f"Download of {file_name} completed successfully.")
+ print(result.stdout.decode())
+ except subprocess.CalledProcessError as e:
+ print(f"An error occurred during the download of {file_name}.")
+ print(e.stderr.decode())
+ else:
+ print(f"{file_name} already exists. Skipping download.")
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b860df46c33b8377a0125a581ec17f61d424766f
--- /dev/null
+++ b/sam2/utils/misc.py
@@ -0,0 +1,244 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import warnings
+from threading import Thread
+from typing import Dict, List, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from torch.nn.attention import SDPBackend
+from einops import rearrange
+from tqdm import tqdm
+import torch.nn.functional as F
+
+VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"]
+
+variant_to_config_mapping: Dict[str, str] = {
+ "tiny": "sam2_hiera_t.yaml",
+ "small": "sam2_hiera_s.yaml",
+ "base_plus": "sam2_hiera_b+.yaml",
+ "large": "sam2_hiera_l.yaml",
+}
+
+
+def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]:
+ backends = []
+ if torch.cuda.is_available():
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
+
+ if torch.cuda.get_device_properties(0).major < 7:
+ backends.append(SDPBackend.EFFICIENT_ATTENTION)
+
+ if use_flash_attn:
+ backends.append(SDPBackend.FLASH_ATTENTION)
+
+ if pytorch_version < (2, 2) or not use_flash_attn:
+ backends.append(SDPBackend.MATH)
+
+ if (
+ SDPBackend.EFFICIENT_ATTENTION in backends and dropout_p > 0.0
+ ) and SDPBackend.MATH not in backends:
+ backends.append(SDPBackend.MATH)
+
+ else:
+ backends.extend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
+
+ return backends
+
+
+def get_connected_components(mask):
+ """
+ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
+
+ Inputs:
+ - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
+ background.
+
+ Outputs:
+ - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
+ for foreground pixels and 0 for background pixels.
+ - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
+ components for foreground pixels and 0 for background pixels.
+ """
+ from sam2 import _C
+
+ return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
+
+
+def mask_to_box(masks: torch.Tensor):
+ """
+ Compute bounding boxes given input masks.
+
+ Inputs:
+ - masks: [B, 1, H, W] binary masks, dtype=torch.Tensor
+
+ Returns:
+ - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
+ """
+ B, _, h, w = masks.shape
+ device = masks.device
+ xs = torch.arange(w, device=device, dtype=torch.int32)
+ ys = torch.arange(h, device=device, dtype=torch.int32)
+ grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
+ grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
+ grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
+ min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
+ max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
+ min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
+ max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
+ bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
+
+ return bbox_coords
+
+
+def _load_img_as_tensor(img_path, image_size):
+ img_pil = Image.open(img_path)
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
+ img_np = img_np / 255.0
+ else:
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
+ video_width, video_height = img_pil.size # the original video size
+ return img, video_height, video_width
+
+
+class AsyncVideoFrameLoader:
+ """
+ A list of video frames to be loaded asynchronously without blocking session start.
+ """
+
+ def __init__(self, img_paths, image_size, img_mean, img_std, device):
+ self.img_paths = img_paths
+ self.image_size = image_size
+ self.img_mean = img_mean
+ self.img_std = img_std
+ self.device = device
+ # items in `self.images` will be loaded asynchronously
+ self.images = [None] * len(img_paths)
+ # catch and raise any exceptions in the async loading thread
+ self.exception = None
+ # video_height and video_width will be filled when loading the first image
+ self.video_height = None
+ self.video_width = None
+
+ # load the first frame to fill video_height and video_width and also
+ # to cache it (since it's most likely where the user will click)
+ self.__getitem__(0)
+
+ # load the rest of frames asynchronously without blocking the session start
+ def _load_frames():
+ try:
+ for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
+ self.__getitem__(n)
+ except Exception as e:
+ self.exception = e
+
+ self.thread = Thread(target=_load_frames, daemon=True)
+ self.thread.start()
+
+ def __getitem__(self, index):
+ if self.exception is not None:
+ raise RuntimeError("Failure in frame loading thread") from self.exception
+
+ img = self.images[index]
+ if img is not None:
+ return img
+
+ img, video_height, video_width = _load_img_as_tensor(
+ self.img_paths[index], self.image_size
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+ # normalize by mean and std
+ img -= self.img_mean
+ img /= self.img_std
+ img = img.to(self.device)
+ self.images[index] = img
+ return img
+
+ def __len__(self):
+ return len(self.images)
+
+
+def load_video_frames(
+ video_path,
+ image_size,
+ images=None,
+ img_mean=(0.485, 0.456, 0.406),
+ img_std=(0.229, 0.224, 0.225),
+ async_loading_frames=False,
+ device="cpu",
+):
+ """
+ Load video frames either from an in-memory array passed via `images` (an array of
+ shape (T, H, W, C)) or from a directory of JPEG files (".jpg" format) at `video_path`.
+
+ Note: `async_loading_frames` is accepted for API compatibility, but in this
+ implementation the frames are loaded synchronously.
+ """
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+ if images is not None:
+ images = torch.from_numpy(images).float()
+ images = rearrange(images, "f h w c -> f c h w")
+ images = F.interpolate(images, (image_size, image_size), mode="bilinear")
+ video_height, video_width = images.shape[2:]
+ else:
+ if isinstance(video_path, str) and os.path.isdir(video_path):
+ jpg_folder = video_path
+ else:
+ raise NotImplementedError("Only JPEG frames are supported at this moment")
+
+ frame_names = [
+ p
+ for p in os.listdir(jpg_folder)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+ ]
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+ num_frames = len(frame_names)
+ if num_frames == 0:
+ raise RuntimeError(f"no images found in {jpg_folder}")
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+ images = images.to(device)
+ img_mean = img_mean.to(device)
+ img_std = img_std.to(device)
+ # normalize by mean and std
+ images -= img_mean
+ images /= img_std
+ return images, video_height, video_width
+
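+# Illustrative usage sketch (the path and sizes below are assumptions, not values
+# defined in this module): load a folder of JPEG frames named 0.jpg, 1.jpg, ...,
+# resized to the model input resolution and normalized with the ImageNet mean/std
+# defaults above.
+#
+#   frames, orig_h, orig_w = load_video_frames(
+#       video_path="./my_video_frames",  # hypothetical directory of JPEGs
+#       image_size=1024,
+#       device="cpu",
+#   )
+#   # frames: float tensor of shape (num_frames, 3, 1024, 1024)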
+
+def fill_holes_in_mask_scores(mask, max_area):
+ """
+ A post processor to fill small holes in mask scores with area under `max_area`.
+ """
+ # Holes are those connected components in background with area <= max_area
+ # (background regions are those with mask scores <= 0)
+ assert max_area > 0, "max_area must be positive"
+ labels, areas = get_connected_components(mask <= 0)
+ is_hole = (labels > 0) & (areas <= max_area)
+ # We fill holes with a small positive mask score (0.1) to change them to foreground.
+ mask = torch.where(is_hole, 0.1, mask)
+ return mask
+
+
+def concat_points(old_point_inputs, new_points, new_labels):
+ """Add new points and labels to previous point inputs (add at the end)."""
+ if old_point_inputs is None:
+ points, labels = new_points, new_labels
+ else:
+ points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
+ labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
+
+ return {"point_coords": points, "point_labels": labels}
diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05cd3e5ebacadca10037c138cbeebfa6b89adab
--- /dev/null
+++ b/sam2/utils/transforms.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Normalize, Resize, ToTensor
+
+
+class SAM2Transforms(nn.Module):
+ def __init__(
+ self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
+ ):
+ """
+ Transforms for SAM2.
+ """
+ super().__init__()
+ self.resolution = resolution
+ self.mask_threshold = mask_threshold
+ self.max_hole_area = max_hole_area
+ self.max_sprinkle_area = max_sprinkle_area
+ self.mean = [0.485, 0.456, 0.406]
+ self.std = [0.229, 0.224, 0.225]
+ self.to_tensor = ToTensor()
+ self.transforms = torch.jit.script(
+ nn.Sequential(
+ Resize((self.resolution, self.resolution)),
+ Normalize(self.mean, self.std),
+ )
+ )
+
+ def __call__(self, x):
+ x = self.to_tensor(x)
+ return self.transforms(x)
+
+ def forward_batch(self, img_list):
+ img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
+ img_batch = torch.stack(img_batch, dim=0)
+ return img_batch
+
+ def transform_coords(
+ self, coords: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in
+ absolute image or normalized coordinates; if the coords are in absolute image
+ coordinates, `normalize` should be set to True and the original image size is required.
+
+ Returns
+ Coordinates scaled to the model input resolution (i.e. in [0, self.resolution]),
+ which is what the SAM2 model expects.
+ """
+ if normalize:
+ assert orig_hw is not None
+ h, w = orig_hw
+ coords = coords.clone()
+ coords[..., 0] = coords[..., 0] / w
+ coords[..., 1] = coords[..., 1] / h
+
+ coords = coords * self.resolution # unnormalize coords
+ return coords
+
+ def transform_boxes(
+ self, boxes: torch.Tensor, normalize=False, orig_hw=None
+ ) -> torch.Tensor:
+ """
+ Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized
+ coordinates; if the coords are in absolute image coordinates, `normalize` should be set
+ to True and the original image size is required.
+ """
+ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
+ return boxes
+
+ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+ """
+ Perform PostProcessing on output masks.
+ """
+ from sam2.utils.misc import get_connected_components
+
+ masks = masks.float()
+ # flatten as 1-channel image once, so it is available to both branches below
+ mask_flat = masks.flatten(0, 1).unsqueeze(1)
+ if self.max_hole_area > 0:
+ # Holes are those connected components in background with area <= self.max_hole_area
+ # (background regions are those with mask scores <= self.mask_threshold)
+ labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
+ is_hole = (labels > 0) & (areas <= self.max_hole_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with a small positive mask score (10.0) to change them to foreground.
+ masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
+
+ if self.max_sprinkle_area > 0:
+ # Sprinkles are small foreground islands with area <= self.max_sprinkle_area
+ labels, areas = get_connected_components(mask_flat > self.mask_threshold)
+ is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
+ is_hole = is_hole.reshape_as(masks)
+ # We fill holes with negative mask score (-10.0) to change them to background.
+ masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+
+ masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
+ return masks
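+
+
+# Illustrative usage sketch (the 1024 resolution, image size, and example point are
+# assumptions, not values fixed by this module):
+#
+#   transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)
+#   point = torch.tensor([[500.0, 375.0]])  # absolute (x, y) pixel coordinates
+#   point_in = transforms.transform_coords(point, normalize=True, orig_hw=(720, 1280))
+#   # `point_in` now lives in the 1024x1024 model input coordinate space.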
diff --git a/sam2/utils/visualization.py b/sam2/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5d5460f9466098fc11c54c735d7545e5290184
--- /dev/null
+++ b/sam2/utils/visualization.py
@@ -0,0 +1,68 @@
+from typing import Optional
+
+import numpy as np
+from PIL import Image
+
+
+def show_masks(
+ image: np.ndarray,
+ masks: np.ndarray,
+ scores: Optional[np.ndarray],
+ alpha: Optional[float] = 0.5,
+ display_image: Optional[bool] = False,
+ only_best: Optional[bool] = True,
+ autogenerated_mask: Optional[bool] = False,
+) -> Image.Image:
+ if scores is not None:
+ # sort masks by their scores
+ sorted_ind = np.argsort(scores)[::-1]
+ masks = masks[sorted_ind]
+
+ if autogenerated_mask:
+ masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
+ else:
+ # get mask dimensions
+ h, w = masks.shape[-2:]
+
+ if display_image:
+ output_image = Image.fromarray(image)
+ else:
+ # create a new blank image to superimpose masks
+ if autogenerated_mask:
+ output_image = Image.new(
+ mode="RGBA",
+ size=(
+ masks[0]["segmentation"].shape[1],
+ masks[0]["segmentation"].shape[0],
+ ),
+ color=(0, 0, 0),
+ )
+ else:
+ output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0))
+
+ for i, mask in enumerate(masks):
+ if not autogenerated_mask:
+ if mask.ndim > 2: # type: ignore
+ mask = mask.squeeze() # type: ignore
+ else:
+ mask = mask["segmentation"]
+ # Generate a random color with specified alpha value
+ color = np.concatenate(
+ (np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0
+ )
+
+ # Create an RGBA image for the mask
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
+ mask_colored = Image.new("RGBA", mask_image.size, tuple(color))
+ mask_image = Image.composite(
+ mask_colored, Image.new("RGBA", mask_image.size), mask_image
+ )
+
+ # Overlay mask on the output image
+ output_image = Image.alpha_composite(output_image, mask_image)
+
+ # Exit if specified to only display the best mask
+ if only_best:
+ break
+
+ return output_image
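+
+
+# Illustrative usage sketch (array shapes and the rectangle are assumptions):
+#
+#   image = np.zeros((480, 640, 3), dtype=np.uint8)  # unused when display_image=False
+#   masks = np.zeros((1, 480, 640), dtype=np.uint8)
+#   masks[0, 100:200, 150:300] = 1
+#   overlay = show_masks(image, masks, scores=None)  # RGBA image with one colored mask
+#   overlay.save("overlay.png")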
diff --git a/vace/__init__.py b/vace/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ba686bc2105181216101c39ae4412c247aadbe
--- /dev/null
+++ b/vace/__init__.py
@@ -0,0 +1,2 @@
+# -*- coding: utf-8 -*-
+from . import models
\ No newline at end of file
diff --git a/vace/__pycache__/__init__.cpython-311.pyc b/vace/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2199364bd7b32768f4ce9c31bf173ad7bfe16946
Binary files /dev/null and b/vace/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/__init__.py b/vace/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d99a0a60c32e0c5b194c8bc62125f2ef4ab7e319
--- /dev/null
+++ b/vace/models/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+from . import utils
+
+try:
+ from . import wan
+except ImportError as e:
+ print("Warning: failed to importing 'wan'. Please install its dependencies with:")
+ print("pip install wan@git+https://github.com/Wan-Video/Wan2.1")
diff --git a/vace/models/__pycache__/__init__.cpython-311.pyc b/vace/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d91315b8bdc5373327ab3018ffab6645e2d8fb38
Binary files /dev/null and b/vace/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/utils/__init__.py b/vace/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d95c410fcc2c3e35b127c99e988a00ee1ad85a19
--- /dev/null
+++ b/vace/models/utils/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .preprocessor import VaceVideoProcessor
\ No newline at end of file
diff --git a/vace/models/utils/__pycache__/__init__.cpython-311.pyc b/vace/models/utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd77e5c03f90365cee502e2a0b0ae65bb378676e
Binary files /dev/null and b/vace/models/utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc b/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d80e69e8b5c58029ae2e81663dd4a0bde8e81fe3
Binary files /dev/null and b/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc differ
diff --git a/vace/models/utils/preprocessor.py b/vace/models/utils/preprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0788111a7b79fda3070a2ab8372956c0726af26
--- /dev/null
+++ b/vace/models/utils/preprocessor.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+
+
+class VaceImageProcessor(object):
+ def __init__(self, downsample=None, seq_len=None):
+ self.downsample = downsample
+ self.seq_len = seq_len
+
+ def _pillow_convert(self, image, cvt_type='RGB'):
+ if image.mode != cvt_type:
+ if image.mode == 'P':
+ image = image.convert(f'{cvt_type}A')
+ if image.mode == f'{cvt_type}A':
+ bg = Image.new(cvt_type,
+ size=(image.width, image.height),
+ color=(255, 255, 255))
+ bg.paste(image, (0, 0), mask=image)
+ image = bg
+ else:
+ image = image.convert(cvt_type)
+ return image
+
+ def _load_image(self, img_path):
+ if img_path is None or img_path == '':
+ return None
+ img = Image.open(img_path)
+ img = self._pillow_convert(img)
+ return img
+
+ def _resize_crop(self, img, oh, ow, normalize=True):
+ """
+ Resize, center crop, convert to tensor, and normalize.
+ """
+ # resize and crop
+ iw, ih = img.size
+ if iw != ow or ih != oh:
+ # resize
+ scale = max(ow / iw, oh / ih)
+ img = img.resize(
+ (round(scale * iw), round(scale * ih)),
+ resample=Image.Resampling.LANCZOS
+ )
+ assert img.width >= ow and img.height >= oh
+
+ # center crop
+ x1 = (img.width - ow) // 2
+ y1 = (img.height - oh) // 2
+ img = img.crop((x1, y1, x1 + ow, y1 + oh))
+
+ # normalize
+ if normalize:
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
+ return img
+
+ def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
+ return self._resize_crop(img, oh, ow, normalize)
+
+ def load_image(self, data_key, **kwargs):
+ return self.load_image_batch(data_key, **kwargs)
+
+ def load_image_pair(self, data_key, data_key2, **kwargs):
+ return self.load_image_batch(data_key, data_key2, **kwargs)
+
+ def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
+ seq_len = self.seq_len if seq_len is None else seq_len
+ imgs = []
+ for data_key in data_key_batch:
+ img = self._load_image(data_key)
+ imgs.append(img)
+ w, h = imgs[0].size
+ dh, dw = self.downsample[1:]
+
+ # compute output size
+ scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
+ oh = int(h * scale) // dh * dh
+ ow = int(w * scale) // dw * dw
+ assert (oh // dh) * (ow // dw) <= seq_len
+ imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
+ return *imgs, (oh, ow)
+
+
+class VaceVideoProcessor(object):
+ def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
+ self.downsample = downsample
+ self.min_area = min_area
+ self.max_area = max_area
+ self.min_fps = min_fps
+ self.max_fps = max_fps
+ self.zero_start = zero_start
+ self.keep_last = keep_last
+ self.seq_len = seq_len
+ assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
+
+ def set_area(self, area):
+ self.min_area = area
+ self.max_area = area
+
+ def set_seq_len(self, seq_len):
+ self.seq_len = seq_len
+
+ @staticmethod
+ def resize_crop(video: torch.Tensor, oh: int, ow: int):
+ """
+ Resize, center crop and normalize for decord loaded video (torch.Tensor type)
+
+ Parameters:
+ video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
+ oh - target height (int)
+ ow - target width (int)
+
+ Returns:
+ The processed video (torch.Tensor): normalized to the range [-1, 1], in shape (C, T, H, W)
+ """
+ # permute ([t, h, w, c] -> [t, c, h, w])
+ video = video.permute(0, 3, 1, 2)
+
+ # resize and crop
+ ih, iw = video.shape[2:]
+ if ih != oh or iw != ow:
+ # resize
+ scale = max(ow / iw, oh / ih)
+ video = F.interpolate(
+ video,
+ size=(round(scale * ih), round(scale * iw)),
+ mode='bicubic',
+ antialias=True
+ )
+ assert video.size(3) >= ow and video.size(2) >= oh
+
+ # center crop
+ x1 = (video.size(3) - ow) // 2
+ y1 = (video.size(2) - oh) // 2
+ video = video[:, :, y1:y1 + oh, x1:x1 + ow]
+
+ # permute ([t, c, h, w] -> [c, t, h, w]) and normalize
+ video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
+ return video
+
+ def _video_preprocess(self, video, oh, ow):
+ return self.resize_crop(video, oh, ow)
+
+ def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
+ target_fps = min(fps, self.max_fps)
+ duration = frame_timestamps[-1].mean()
+ x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
+ h, w = y2 - y1, x2 - x1
+ ratio = h / w
+ df, dh, dw = self.downsample
+
+ area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
+ of = min(
+ (int(duration * target_fps) - 1) // df + 1,
+ int(self.seq_len / area_z)
+ )
+
+ # deduce target shape of the [latent video]
+ target_area_z = min(area_z, int(self.seq_len / of))
+ oh = round(np.sqrt(target_area_z * ratio))
+ ow = int(target_area_z / oh)
+ of = (of - 1) * df + 1
+ oh *= dh
+ ow *= dw
+
+ # sample frame ids
+ target_duration = of / target_fps
+ begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
+ timestamps = np.linspace(begin, begin + target_duration, of)
+ frame_ids = np.argmax(np.logical_and(
+ timestamps[:, None] >= frame_timestamps[None, :, 0],
+ timestamps[:, None] < frame_timestamps[None, :, 1]
+ ), axis=1).tolist()
+ return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
+
+ def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
+ duration = frame_timestamps[-1].mean()
+ x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
+ h, w = y2 - y1, x2 - x1
+ ratio = h / w
+ df, dh, dw = self.downsample
+
+ area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
+ of = min(
+ (len(frame_timestamps) - 1) // df + 1,
+ int(self.seq_len / area_z)
+ )
+
+ # deduce target shape of the [latent video]
+ target_area_z = min(area_z, int(self.seq_len / of))
+ oh = round(np.sqrt(target_area_z * ratio))
+ ow = int(target_area_z / oh)
+ of = (of - 1) * df + 1
+ oh *= dh
+ ow *= dw
+
+ # sample frame ids
+ target_duration = duration
+ target_fps = of / target_duration
+ timestamps = np.linspace(0., target_duration, of)
+ frame_ids = np.argmax(np.logical_and(
+ timestamps[:, None] >= frame_timestamps[None, :, 0],
+ timestamps[:, None] <= frame_timestamps[None, :, 1]
+ ), axis=1).tolist()
+ # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
+ return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
+
+
+ def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
+ if self.keep_last:
+ return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
+ else:
+ return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
+
+ def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
+ return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
+
+ def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
+ return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
+
+ def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
+ rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
+ # read video
+ import decord
+ decord.bridge.set_bridge('torch')
+ readers = []
+ for data_k in data_key_batch:
+ reader = decord.VideoReader(data_k)
+ readers.append(reader)
+
+ fps = readers[0].get_avg_fps()
+ length = min([len(r) for r in readers])
+ frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
+ frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
+ h, w = readers[0].next().shape[:2]
+ frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
+
+ # preprocess video
+ videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
+ videos = [self._video_preprocess(video, oh, ow) for video in videos]
+ return *videos, frame_ids, (oh, ow), fps
+ # return videos if len(videos) > 1 else videos[0]
+
+
+def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
+ for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
+ if sub_src_video is None and sub_src_mask is None:
+ src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
+ src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
+ for i, ref_images in enumerate(src_ref_images):
+ if ref_images is not None:
+ for j, ref_img in enumerate(ref_images):
+ if ref_img is not None and ref_img.shape[-2:] != image_size:
+ canvas_height, canvas_width = image_size
+ ref_height, ref_width = ref_img.shape[-2:]
+ white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
+ scale = min(canvas_height / ref_height, canvas_width / ref_width)
+ new_height = int(ref_height * scale)
+ new_width = int(ref_width * scale)
+ resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+ top = (canvas_height - new_height) // 2
+ left = (canvas_width - new_width) // 2
+ white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+ src_ref_images[i][j] = white_canvas
+ return src_video, src_mask, src_ref_images
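
Note on the letterboxing step in `prepare_source` above: reference images are rescaled to fit the target canvas and pasted centered onto a white background in [-1, 1] space. A minimal, self-contained sketch of just that logic, with made-up tensor shapes (only the pattern mirrors the code above):

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: letterbox a (3, 1, 120, 200) reference image onto a 480x832 canvas.
canvas_h, canvas_w = 480, 832
ref = torch.rand(3, 1, 120, 200) * 2 - 1             # (C, T=1, H, W), values in [-1, 1]

white_canvas = torch.ones(3, 1, canvas_h, canvas_w)  # white background in [-1, 1] space
scale = min(canvas_h / ref.shape[-2], canvas_w / ref.shape[-1])
new_h, new_w = int(ref.shape[-2] * scale), int(ref.shape[-1] * scale)

# drop the singleton time dim for 2D interpolation, then restore it
resized = F.interpolate(ref.squeeze(1).unsqueeze(0), size=(new_h, new_w),
                        mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)

top, left = (canvas_h - new_h) // 2, (canvas_w - new_w) // 2
white_canvas[:, :, top:top + new_h, left:left + new_w] = resized
print(white_canvas.shape)  # torch.Size([3, 1, 480, 832])
```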
diff --git a/vace/models/wan/__init__.py b/vace/models/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c9f319186abd76a97bd448b6ceb57564e117c80
--- /dev/null
+++ b/vace/models/wan/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from . import modules
+from .wan_vace import WanVace
diff --git a/vace/models/wan/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36e5b7bd26c4a6d3fee9b38652f21300293b1155
Binary files /dev/null and b/vace/models/wan/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc b/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95b528699128b5584f2e2351a7801e1b61521b54
Binary files /dev/null and b/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc differ
diff --git a/vace/models/wan/configs/__init__.py b/vace/models/wan/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ebd0dd2a619b230a7f7fd6e36326be020c7882
--- /dev/null
+++ b/vace/models/wan/configs/__init__.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+from .wan_t2v_1_3B import t2v_1_3B
+from .wan_t2v_14B import t2v_14B
+
+WAN_CONFIGS = {
+ 'vace-1.3B': t2v_1_3B,
+ 'vace-14B': t2v_14B,
+}
+
+SIZE_CONFIGS = {
+ '720*1280': (720, 1280),
+ '1280*720': (1280, 720),
+ '480*832': (480, 832),
+ '832*480': (832, 480),
+ '1024*1024': (1024, 1024),
+ '720p': (1280, 720),
+ '480p': (480, 832)
+}
+
+MAX_AREA_CONFIGS = {
+ '720*1280': 720 * 1280,
+ '1280*720': 1280 * 720,
+ '480*832': 480 * 832,
+ '832*480': 832 * 480,
+ '720p': 1280 * 720,
+ '480p': 480 * 832
+}
+
+SUPPORTED_SIZES = {
+ 'vace-1.3B': ('480*832', '832*480', '480p'),
+ 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480', '480p', '720p')
+}
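
For orientation, a minimal usage sketch of these lookup tables; the model and size keys come from the dictionaries above, while the surrounding variable names are hypothetical and the import assumes the `vace` package is importable:

```python
# Hypothetical usage: resolve a model name and size string into concrete settings.
from vace.models.wan.configs import (MAX_AREA_CONFIGS, SIZE_CONFIGS,
                                     SUPPORTED_SIZES, WAN_CONFIGS)

model_name, size = 'vace-1.3B', '480*832'
assert size in SUPPORTED_SIZES[model_name], f'{size} is not supported by {model_name}'

cfg = WAN_CONFIGS[model_name]        # EasyDict with dim, num_layers, vae_stride, ...
height, width = SIZE_CONFIGS[size]
max_area = MAX_AREA_CONFIGS[size]
print(cfg.dim, cfg.num_layers, height, width, max_area)  # 1536 30 480 832 399360
```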
diff --git a/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b7140198e5b3ae3f3db9401912d2d2742214e7f
Binary files /dev/null and b/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc b/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c21b0fe4f8c9717b4ab76019edbb79049d21cb4
Binary files /dev/null and b/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc differ
diff --git a/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc b/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a93579703336219d5add35a27fb558aa28573c73
Binary files /dev/null and b/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc differ
diff --git a/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc b/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cb7d6a10dd3636109d5705516e6973ebc64d77c
Binary files /dev/null and b/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc differ
diff --git a/vace/models/wan/configs/shared_config.py b/vace/models/wan/configs/shared_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e5cef2540f4ffb00dab185c93bd42b30b89989
--- /dev/null
+++ b/vace/models/wan/configs/shared_config.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+from easydict import EasyDict
+
+#------------------------ Wan shared config ------------------------#
+wan_shared_cfg = EasyDict()
+
+# t5
+wan_shared_cfg.t5_model = 'umt5_xxl'
+wan_shared_cfg.t5_dtype = torch.bfloat16
+wan_shared_cfg.text_len = 512
+
+# transformer
+wan_shared_cfg.param_dtype = torch.float16
+
+# inference
+wan_shared_cfg.num_train_timesteps = 1000
+wan_shared_cfg.sample_fps = 16
+wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'  # default negative prompt (Chinese), roughly: "garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards"
diff --git a/vace/models/wan/configs/wan_t2v_14B.py b/vace/models/wan/configs/wan_t2v_14B.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0ee69dea796bfd6eccdedf4ec04835086227a6
--- /dev/null
+++ b/vace/models/wan/configs/wan_t2v_14B.py
@@ -0,0 +1,29 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+#------------------------ Wan T2V 14B ------------------------#
+
+t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
+t2v_14B.update(wan_shared_cfg)
+
+# t5
+t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+t2v_14B.t5_tokenizer = 'google/umt5-xxl'
+
+# vae
+t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
+t2v_14B.vae_stride = (4, 8, 8)
+
+# transformer
+t2v_14B.patch_size = (1, 2, 2)
+t2v_14B.dim = 5120
+t2v_14B.ffn_dim = 13824
+t2v_14B.freq_dim = 256
+t2v_14B.num_heads = 40
+t2v_14B.num_layers = 40
+t2v_14B.window_size = (-1, -1)
+t2v_14B.qk_norm = True
+t2v_14B.cross_attn_norm = True
+t2v_14B.eps = 1e-6
diff --git a/vace/models/wan/configs/wan_t2v_1_3B.py b/vace/models/wan/configs/wan_t2v_1_3B.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd1464ec010e7bf2570e4375e2814a0943a189a
--- /dev/null
+++ b/vace/models/wan/configs/wan_t2v_1_3B.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from easydict import EasyDict
+
+from .shared_config import wan_shared_cfg
+
+#------------------------ Wan T2V 1.3B ------------------------#
+
+t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
+t2v_1_3B.update(wan_shared_cfg)
+
+# t5
+t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
+t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
+
+# vae
+t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
+t2v_1_3B.vae_stride = (4, 8, 8)
+
+# transformer
+t2v_1_3B.patch_size = (1, 2, 2)
+t2v_1_3B.dim = 1536
+t2v_1_3B.ffn_dim = 8960
+t2v_1_3B.freq_dim = 256
+t2v_1_3B.num_heads = 12
+t2v_1_3B.num_layers = 30
+t2v_1_3B.window_size = (-1, -1)
+t2v_1_3B.qk_norm = True
+t2v_1_3B.cross_attn_norm = True
+t2v_1_3B.eps = 1e-6
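
The `vae_stride` and `patch_size` entries above determine how many tokens the transformer sees for a given clip. A short worked sketch with illustrative numbers (the `(frames - 1) // stride + 1` temporal rule is an assumption about the causal VAE, not something fixed by this config file):

```python
# Illustrative numbers only: 81 frames at 480x832 with the 1.3B config above.
frames, height, width = 81, 480, 832
vae_stride = (4, 8, 8)     # temporal, height, width compression of the VAE
patch_size = (1, 2, 2)     # 3D patchify of the transformer

lat_f = (frames - 1) // vae_stride[0] + 1                        # 21 latent frames (assumed causal-VAE rule)
lat_h, lat_w = height // vae_stride[1], width // vae_stride[2]   # 60 x 104 latent pixels
seq_len = (lat_f // patch_size[0]) * (lat_h // patch_size[1]) * (lat_w // patch_size[2])
print(lat_f, lat_h, lat_w, seq_len)  # 21 60 104 32760
```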
diff --git a/vace/models/wan/distributed/__init__.py b/vace/models/wan/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13a6b25acc146323b5d4769bf7ed6abc3b5d7d68
--- /dev/null
+++ b/vace/models/wan/distributed/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .xdit_context_parallel import pad_freqs, rope_apply, usp_dit_forward_vace, usp_dit_forward, usp_attn_forward
\ No newline at end of file
diff --git a/vace/models/wan/distributed/xdit_context_parallel.py b/vace/models/wan/distributed/xdit_context_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..aacf47fa041966d857a3c68bed48e9363375802d
--- /dev/null
+++ b/vace/models/wan/distributed/xdit_context_parallel.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.cuda.amp as amp
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ get_sp_group)
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+
+from ..modules.model import sinusoidal_embedding_1d
+
+
+def pad_freqs(original_tensor, target_len):
+ seq_len, s1, s2 = original_tensor.shape
+ pad_size = target_len - seq_len
+ padding_tensor = torch.ones(
+ pad_size,
+ s1,
+ s2,
+ dtype=original_tensor.dtype,
+ device=original_tensor.device)
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+ return padded_tensor
+
+
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ """
+ x: [B, L, N, C].
+ grid_sizes: [B, 3].
+ freqs: [M, C // 2].
+ """
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+ s, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ sp_size = get_sequence_parallel_world_size()
+ sp_rank = get_sequence_parallel_rank()
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
+ s_per_rank = s
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+ s_per_rank), :, :]
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+ x_i = torch.cat([x_i, x[i, s:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
+
+
+def usp_dit_forward_vace(
+ self,
+ x,
+ vace_context,
+ seq_len,
+ kwargs
+):
+ # embeddings
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+ c = [u.flatten(2).transpose(1, 2) for u in c]
+ c = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in c
+ ])
+
+ # arguments
+ new_kwargs = dict(x=x)
+ new_kwargs.update(kwargs)
+
+ # Context Parallel
+ c = torch.chunk(
+ c, get_sequence_parallel_world_size(),
+ dim=1)[get_sequence_parallel_rank()]
+
+ for block in self.vace_blocks:
+ c = block(c, **new_kwargs)
+ hints = torch.unbind(c)[:-1]
+ return hints
+
+
+def usp_dit_forward(
+ self,
+ x,
+ t,
+ vace_context,
+ context,
+ seq_len,
+ vace_context_scale=1.0,
+ clip_fea=None,
+ y=None,
+):
+ """
+ x: A list of videos each with shape [C, T, H, W].
+ t: [B].
+ context: A list of text embeddings each with shape [L, C].
+ """
+ # params
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ # if y is not None:
+ # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+ for u in x
+ ])
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ # if clip_fea is not None:
+ # context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ # context = torch.concat([context_clip, context], dim=1)
+
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens)
+
+ # Context Parallel
+ x = torch.chunk(
+ x, get_sequence_parallel_world_size(),
+ dim=1)[get_sequence_parallel_rank()]
+
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
+ kwargs['hints'] = hints
+ kwargs['context_scale'] = vace_context_scale
+
+ for block in self.blocks:
+ x = block(x, **kwargs)
+
+ # head
+ x = self.head(x, e)
+
+ # Context Parallel
+ x = get_sp_group().all_gather(x, dim=1)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+
+def usp_attn_forward(self,
+ x,
+ seq_lens,
+ grid_sizes,
+ freqs,
+ dtype=torch.bfloat16):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+ half_dtypes = (torch.float16, torch.bfloat16)
+
+ def half(x):
+ return x if x.dtype in half_dtypes else x.to(dtype)
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+ q = rope_apply(q, grid_sizes, freqs)
+ k = rope_apply(k, grid_sizes, freqs)
+
+ # TODO: We should use unpadded q, k, v for attention.
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
+ # if k_lens is not None:
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+ x = xFuserLongContextAttention()(
+ None,
+ query=half(q),
+ key=half(k),
+ value=half(v),
+ window_size=self.window_size)
+
+ # TODO: padding after attention.
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
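
To see what `pad_freqs` plus the per-rank slice inside `rope_apply` accomplishes under sequence parallelism, here is a tiny standalone simulation with plain tensors (no xfuser required; the rank count and sequence lengths are made up):

```python
import torch

def pad_freqs(original_tensor, target_len):
    # same logic as the helper above: pad with ones so padded slots rotate by the identity
    seq_len, s1, s2 = original_tensor.shape
    padding = torch.ones(target_len - seq_len, s1, s2,
                         dtype=original_tensor.dtype, device=original_tensor.device)
    return torch.cat([original_tensor, padding], dim=0)

# Made-up sizes: the full sequence has 10 rotary positions, 2 ranks hold 6 tokens each.
full_freqs = torch.polar(torch.ones(10, 1, 4), torch.rand(10, 1, 4))  # complex rotations
sp_size, s_per_rank = 2, 6

padded = pad_freqs(full_freqs, sp_size * s_per_rank)   # 10 -> 12 positions
for sp_rank in range(sp_size):
    freqs_i_rank = padded[sp_rank * s_per_rank:(sp_rank + 1) * s_per_rank]
    print(sp_rank, freqs_i_rank.shape)  # each rank slices exactly the positions of its shard
```

Padding with ones means the padded positions are rotated by the identity, so only real tokens pick up a rotary phase.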
diff --git a/vace/models/wan/modules/__init__.py b/vace/models/wan/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..307c3dd672ce42d271d0bff77a43c071aa32e271
--- /dev/null
+++ b/vace/models/wan/modules/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .model import VaceWanAttentionBlock, BaseWanAttentionBlock, VaceWanModel
\ No newline at end of file
diff --git a/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32b3118f577a488e5f10f8805df2c06d90624bc8
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc b/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f7a8f2e7fb91c94bace996e2a6e29e9e47b8027
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/model.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f151dd6b1295f3fae31d4473b0e055c8fc722ae2
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..076e90fb3a6da535a99f30da7be29b566600f1ec
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..616e207ce99e05d1b4676975a9353f025620880e
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82183bc19c0ae07232079ad125283a0b82d4f158
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..104ca0acabd89f205e09b768bf99c65df0cfa1d5
Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc differ
diff --git a/vace/models/wan/modules/mm_attention.py b/vace/models/wan/modules/mm_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..84f2e0bcb055bc8bb26ed024c5335a6399e30a7d
--- /dev/null
+++ b/vace/models/wan/modules/mm_attention.py
@@ -0,0 +1,747 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+from wan.modules.attention import flash_attention
+from wan.modules.model import WanSelfAttention, WanAttentionBlock
+
+__all__ = ['MMModel']
+
+T5_CONTEXT_TOKEN_NUMBER = 512
+FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
+
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float64)
+
+ # calculation
+ sinusoid = torch.outer(
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
+
+
+@amp.autocast(enabled=False)
+def rope_params(max_seq_len, dim, theta=10000):
+ assert dim % 2 == 0
+ freqs = torch.outer(
+ torch.arange(max_seq_len),
+ 1.0 / torch.pow(theta,
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+
+@amp.autocast(enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+ n, c = x.size(2), x.size(3) // 2
+
+ # split freqs
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+ # loop over samples
+ output = []
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+ seq_len = f * h * w
+
+ # precompute multipliers
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+ seq_len, n, -1, 2))
+ freqs_i = torch.cat([
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+ ],
+ dim=-1).reshape(seq_len, 1, -1)
+
+ # apply rotary embedding
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
+ x_i = torch.cat([x_i, x[i, seq_len:]])
+
+ # append to collection
+ output.append(x_i)
+ return torch.stack(output).float()
+
+
+class WanRMSNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.dim = dim
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return self._norm(x.float()).type_as(x) * self.weight
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+
+class WanLayerNorm(nn.LayerNorm):
+
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+# class MMSelfAttention(nn.Module):
+
+# def __init__(self,
+# dim,
+# num_heads,
+# window_size=(-1, -1),
+# qk_norm=True,
+# eps=1e-6):
+# assert dim % num_heads == 0
+# super().__init__()
+# self.dim = dim
+# self.num_heads = num_heads
+# self.head_dim = dim // num_heads
+# self.window_size = window_size
+# self.qk_norm = qk_norm
+# self.eps = eps
+
+# # layers
+# self.q = nn.Linear(dim, dim)
+# self.k = nn.Linear(dim, dim)
+# self.v = nn.Linear(dim, dim)
+# self.o = nn.Linear(dim, dim)
+# self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+# self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+# def forward(self, x, seq_lens, grid_sizes, grid_sizes_ref, freqs, len1):
+# r"""
+# Args:
+# x(Tensor): Shape [B, L, num_heads, C / num_heads]
+# seq_lens(Tensor): Shape [B]
+# grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+# freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+# """
+# b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+# # query, key, value function
+# def qkv_fn(x):
+# q = self.norm_q(self.q(x)).view(b, s, n, d)
+# k = self.norm_k(self.k(x)).view(b, s, n, d)
+# v = self.v(x).view(b, s, n, d)
+# return q, k, v
+
+# q, k, v = qkv_fn(x)
+# q[:, :len1] = rope_apply(q[:, :len1], grid_sizes, freqs)
+# k[:, :len1] = rope_apply(k[:, :len1], grid_sizes, freqs)
+# q[:, len1:] = rope_apply(q[:, len1:], grid_sizes_ref, freqs)
+# k[:, len1:] = rope_apply(k[:, len1:], grid_sizes_ref, freqs)
+
+# x = flash_attention(
+# q=q,
+# k=k,
+# v=v,
+# k_lens=seq_lens,
+# window_size=self.window_size)
+
+# # output
+# x = x.flatten(2)
+# x = self.o(x)
+# return x
+
+class MMSelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.o = nn.Linear(dim, dim)
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ self.q_ref = nn.Linear(dim, dim)
+ self.k_ref = nn.Linear(dim, dim)
+ self.v_ref = nn.Linear(dim, dim)
+ self.norm_q_ref = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_k_ref = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, x_ref, seq_lens, grid_sizes, grid_sizes_ref, freqs, freqs_ref, len1):
+ r"""
+ Args:
+ x(Tensor): Video-stream tokens, shape [B, L, C]
+ x_ref(Tensor): Reference-stream tokens, shape [B, L_ref, C]; projected with the *_ref layers and concatenated in front of x before attention
+ seq_lens(Tensor): Shape [B]
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+ b_ref, s_ref = x_ref.shape[:2]
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n, d)
+ return q, k, v
+ def qkv_fn_ref(x_ref):
+ q_ref = self.norm_q_ref(self.q_ref(x_ref)).view(b_ref, s_ref, n, d)
+ k_ref = self.norm_k_ref(self.k_ref(x_ref)).view(b_ref, s_ref, n, d)
+ v_ref = self.v_ref(x_ref).view(b_ref, s_ref, n, d)
+ return q_ref, k_ref, v_ref
+
+ q, k, v = qkv_fn(x)
+ q_ref, k_ref, v_ref = qkv_fn_ref(x_ref)
+ q = torch.cat([q_ref, q], dim=1)
+ k = torch.cat([k_ref, k], dim=1)
+ v = torch.cat([v_ref, v], dim=1)
+
+ q = rope_apply(q, grid_sizes, freqs)
+ k = rope_apply(k, grid_sizes, freqs)
+ x = flash_attention(
+ q=q,
+ k=k,
+ v=v,
+ k_lens=seq_lens,
+ window_size=self.window_size)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+
+ # compute attention
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ x = self.o(x)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
+
+ self.k_img = nn.Linear(dim, dim)
+ self.v_img = nn.Linear(dim, dim)
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
+
+ def forward(self, x, context, context_lens):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ context_lens(Tensor): Shape [B]
+ """
+ image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
+ context_img = context[:, :image_context_length]
+ context = context[:, image_context_length:]
+ b, n, d = x.size(0), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
+ v = self.v(context).view(b, -1, n, d)
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
+ v_img = self.v_img(context_img).view(b, -1, n, d)
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
+ # compute attention
+ x = flash_attention(q, k, v, k_lens=context_lens)
+
+ # output
+ x = x.flatten(2)
+ img_x = img_x.flatten(2)
+ x = x + img_x
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ 't2v_cross_attn': WanT2VCrossAttention,
+ 'i2v_cross_attn': WanI2VCrossAttention,
+}
+
+
+class MMAttentionBlock(nn.Module):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ dim_ref,
+ ffn_dim,
+ ffn_dim_ref,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.dim_ref = dim_ref
+ self.ffn_dim = ffn_dim
+ self.ffn_dim_ref = ffn_dim_ref
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = WanLayerNorm(dim, eps)
+ self.self_attn = MMSelfAttention(dim, num_heads, window_size, qk_norm,
+ eps)
+ self.norm3 = WanLayerNorm(
+ dim, eps,
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+ num_heads,
+ (-1, -1),
+ qk_norm,
+ eps)
+ self.norm2 = WanLayerNorm(dim, eps)
+ self.ffn = nn.Sequential(
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
+ nn.Linear(ffn_dim, dim))
+
+ self.norm1_ref = WanLayerNorm(dim_ref, eps)
+ self.norm3_ref = WanLayerNorm(
+ dim_ref, eps,
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.norm2_ref = WanLayerNorm(dim_ref, eps)
+ self.ffn_ref = nn.Sequential(
+ nn.Linear(dim_ref, ffn_dim_ref), nn.GELU(approximate='tanh'),
+ nn.Linear(ffn_dim_ref, dim_ref))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+ self.modulation_ref = nn.Parameter(torch.randn(1, 6, dim_ref) / dim_ref**0.5)
+
+ def forward(
+ self,
+ x,
+ x_ref,
+ e,
+ e_ref,
+ seq_lens,
+ seq_lens_ref,
+ grid_sizes,
+ grid_sizes_ref,
+ freqs,
+ freqs_ref,
+ context,
+ context_lens,
+ ):
+ r"""
+ Args:
+ x(Tensor): Video-stream tokens, shape [B, L, C]
+ x_ref(Tensor): Reference-stream tokens, shape [B, L_ref, C_ref]
+ e(Tensor): Video-stream modulation, shape [B, 6, C]
+ e_ref(Tensor): Reference-stream modulation, shape [B, 6, C_ref]
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ assert e.dtype == torch.float32 and e_ref.dtype == torch.float32
+ with amp.autocast(dtype=torch.float32):
+ e = (self.modulation + e).chunk(6, dim=1)
+ e_ref = (self.modulation_ref + e_ref).chunk(6, dim=1)
+ assert e[0].dtype == torch.float32 and e_ref[0].dtype == torch.float32
+
+ len1, len2 = x.shape[1], x_ref.shape[1]
+
+ # self-attention
+ y = self.self_attn(
+ self.norm1(x).float() * (1 + e[1]) + e[0], self.norm1_ref(x_ref).float() * (1 + e_ref[1]) + e_ref[0],
+ seq_lens, grid_sizes, grid_sizes_ref, freqs, freqs_ref, len1)
+ y_ref, y = y[:, :len2], y[:, len2:]
+ with amp.autocast(dtype=torch.float32):
+ x = x + y * e[2]
+ x_ref = x_ref + y_ref * e_ref[2]
+
+ # cross-attention & ffn function
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
+ with amp.autocast(dtype=torch.float32):
+ x = x + y * e[5]
+
+ y_ref = self.ffn_ref(self.norm2_ref(x_ref).float() * (1 + e_ref[4]) + e_ref[3])
+ with amp.autocast(dtype=torch.float32):
+ x_ref = x_ref + y_ref * e_ref[5]
+ return x, x_ref
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = WanLayerNorm(dim, eps)
+ self.head = nn.Linear(dim, out_dim)
+
+ # modulation
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ assert e.dtype == torch.float32
+ with amp.autocast(dtype=torch.float32):
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim, flf_pos_emb=False):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
+ torch.nn.LayerNorm(out_dim))
+ if flf_pos_emb: # NOTE: we only use this for `flf2v`
+ self.emb_pos = nn.Parameter(
+ torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
+
+ def forward(self, image_embeds):
+ if hasattr(self, 'emb_pos'):
+ bs, n, d = image_embeds.shape
+ image_embeds = image_embeds.view(-1, 2 * n, d)
+ image_embeds = image_embeds + self.emb_pos
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class MMModel(ModelMixin, ConfigMixin):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ ignore_for_config = [
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
+ ]
+ _no_split_modules = ['WanAttentionBlock']
+
+ @register_to_config
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ dim_ref=2048,
+ ffn_dim=8192,
+ ffn_dim_ref=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 2048):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 8192):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 32):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+ cross_attn_norm (`bool`, *optional*, defaults to True):
+ Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ """
+
+ super().__init__()
+
+ assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.dim_ref = dim_ref
+ self.ffn_dim = ffn_dim
+ self.ffn_dim_ref = ffn_dim_ref
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # embeddings
+ self.patch_embedding = nn.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
+ self.text_embedding = nn.Sequential(
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
+ nn.Linear(dim, dim))
+
+ self.time_embedding = nn.Sequential(
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
+ self.time_embedding_ref = nn.Sequential(
+ nn.Linear(freq_dim, dim_ref), nn.SiLU(), nn.Linear(dim_ref, dim_ref))
+ self.time_projection_ref = nn.Sequential(nn.SiLU(), nn.Linear(dim_ref, dim_ref * 6))
+
+ # blocks
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+
+ self.blocks = nn.ModuleList([
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps)
+ for _ in range(num_layers)
+ ])
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps)
+
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
+ d = dim // num_heads
+ self.freqs = torch.cat([
+ rope_params(1024, d - 4 * (d // 6)),
+ rope_params(1024, 2 * (d // 6)),
+ rope_params(1024, 2 * (d // 6))
+ ],
+ dim=1)
+ self.freqs_ref = torch.cat([
+ rope_params(1024, d - 4 * (d // 6), 5000),
+ rope_params(1024, 2 * (d // 6), 5000),
+ rope_params(1024, 2 * (d // 6), 5000)
+ ],
+ dim=1)
+
+ if model_type == 'i2v' or model_type == 'flf2v':
+ self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
+
+ # initialize weights
+ self.init_weights()
+
+ def forward(
+ self,
+ x,
+ t,
+ context,
+ seq_len,
+ clip_fea=None,
+ y=None,
+ ):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ clip_fea (Tensor, *optional*):
+ CLIP image features for image-to-video mode or first-last-frame-to-video mode
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ if self.model_type == 'i2v' or self.model_type == 'flf2v':
+ assert clip_fea is not None and y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ if y is not None:
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in x
+ ])
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat(
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ if clip_fea is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim
+ context = torch.concat([context_clip, context], dim=1)
+
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens)
+
+ for block in self.blocks:
+ x = block(x, **kwargs)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
+
+ def unpatchify(self, x, grid_sizes):
+ r"""
+ Reconstruct video tensors from patch embeddings.
+
+ Args:
+ x (List[Tensor]):
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
+ grid_sizes (Tensor):
+ Original spatial-temporal grid dimensions before patching,
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
+
+ Returns:
+ List[Tensor]:
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
+ """
+
+ c = self.out_dim
+ out = []
+ for u, v in zip(x, grid_sizes.tolist()):
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
+ out.append(u)
+ return out
+
+ def init_weights(self):
+ r"""
+ Initialize model parameters using Xavier initialization.
+ """
+
+ # basic init
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ # init embeddings
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
+ for m in self.text_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+ for m in self.time_embedding.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, std=.02)
+
+ # init output layer
+ nn.init.zeros_(self.head.head.weight)
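
A quick standalone shape check of the embedding helpers defined in this file; the functions are copied inline so the snippet runs without the `wan`/`flash_attention` dependencies, and the head dimension of 128 is only an example:

```python
import torch

def sinusoidal_embedding_1d(dim, position):
    # same math as the helper above
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)

def rope_params(max_seq_len, dim, theta=10000):
    # same math as the helper above
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

t = torch.tensor([0, 250, 999])
print(sinusoidal_embedding_1d(256, t).shape)   # torch.Size([3, 256]) -- freq_dim=256

d = 128  # example head_dim, e.g. dim // num_heads = 1536 // 12
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
], dim=1)
print(freqs.shape, freqs.dtype)                # torch.Size([1024, 64]) torch.complex128
```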
diff --git a/vace/models/wan/modules/model.py b/vace/models/wan/modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..31efed7b9fa90a5061263f30f74dd3da79f2a127
--- /dev/null
+++ b/vace/models/wan/modules/model.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import register_to_config
+from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
+import torch.utils.checkpoint
+
+def create_custom_forward(module):
+ def custom_forward(*inputs, **kwargs):
+ return module(*inputs, **kwargs)
+ return custom_forward
+
+
+def gradient_checkpoint_forward(
+ model,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ *args,
+ **kwargs,
+):
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ else:
+ model_output = model(*args, **kwargs)
+ return model_output
+
+
+class VaceWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=0
+ ):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.before_proj.weight)
+ nn.init.zeros_(self.before_proj.bias)
+ self.after_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.after_proj.weight)
+ nn.init.zeros_(self.after_proj.bias)
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c))
+ c = all_c.pop(-1)
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ all_c += [c_skip, c]
+ c = torch.stack(all_c)
+ return c
+
+
+class BaseWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=None
+ ):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
+ self.block_id = block_id
+
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
+ x = super().forward(x, **kwargs)
+ if self.block_id is not None:
+ x = x + hints[self.block_id] * context_scale
+ return x
+
+
+class VaceWanModel(WanModel):
+ @register_to_config
+ def __init__(self,
+ vace_layers=None,
+ vace_in_dim=None,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6):
+ model_type = "t2v" # TODO: hard-coded to "t2v" for both the preview and official versions.
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
+
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
+
+ assert 0 in self.vace_layers
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
+ self.cross_attn_norm, self.eps,
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
+ for i in range(self.num_layers)
+ ])
+
+ # vace blocks
+ self.vace_blocks = nn.ModuleList([
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
+ self.cross_attn_norm, self.eps, block_id=i)
+ for i in self.vace_layers
+ ])
+
+ # vace patch embeddings
+ self.vace_patch_embedding = nn.Conv3d(
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ def forward_vace(
+ self,
+ x,
+ vace_context,
+ seq_len,
+ kwargs,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ # embeddings
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+ c = [u.flatten(2).transpose(1, 2) for u in c]
+ c = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in c
+ ])
+
+ new_kwargs = dict(x=x)
+ new_kwargs.update(kwargs)
+
+ for block in self.vace_blocks:
+ c = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ c,
+ **new_kwargs,
+ )
+ hints = torch.unbind(c)[:-1]
+ return hints
+
+ def forward(
+ self,
+ x,
+ t,
+ vace_context,
+ context,
+ seq_len,
+ vace_context_scale=1.0,
+ clip_fea=None,
+ y=None,
+ use_gradient_checkpointing=True,
+ use_gradient_checkpointing_offload=False,
+ ):
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ assert seq_lens.max() <= seq_len
+ x = torch.cat([
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ dim=1) for u in x
+ ])
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+ # text embedding
+ context_lens = None
+ context = self.text_embedding(context)
+
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens,
+ )
+
+ hints = self.forward_vace(
+ x, vace_context, seq_len, kwargs,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ kwargs['hints'] = hints
+ kwargs['context_scale'] = vace_context_scale
+
+ for block in self.blocks:
+ x = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ x,
+ **kwargs,
+ )
+
+ x = self.head(x, e)
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
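
The `gradient_checkpoint_forward` helper above is the standard non-reentrant checkpointing pattern, optionally spilling saved activations to CPU. A minimal self-contained sketch of the same two modes (the module and sizes are invented for illustration):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 64, requires_grad=True)

# recompute activations during backward instead of storing them (non-reentrant variant)
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)

# same, but additionally keep whatever is saved on the CPU instead of accelerator memory
with torch.autograd.graph.save_on_cpu():
    y_offload = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)

(y.sum() + y_offload.sum()).backward()
print(x.grad.shape)  # torch.Size([2, 64])
```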
diff --git a/vace/models/wan/modules/model_mm.py b/vace/models/wan/modules/model_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc531eafbb31456ef909ad1b4db2d4b2dc5477de
--- /dev/null
+++ b/vace/models/wan/modules/model_mm.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import register_to_config
+from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
+from .mm_attention import MMModel, MMAttentionBlock
+import torch.utils.checkpoint
+
+def create_custom_forward(module):
+ def custom_forward(*inputs, **kwargs):
+ return module(*inputs, **kwargs)
+ return custom_forward
+
+
+def gradient_checkpoint_forward(
+ model,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ *args,
+ **kwargs,
+):
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ else:
+ model_output = model(*args, **kwargs)
+ return model_output
+
+
+class VaceWanAttentionBlock(MMAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ dim_ref,
+ ffn_dim,
+ ffn_dim_ref,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=0
+ ):
+ super().__init__(cross_attn_type, dim, dim_ref, ffn_dim, ffn_dim_ref, num_heads, window_size, qk_norm, cross_attn_norm, eps)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.before_proj.weight)
+ nn.init.zeros_(self.before_proj.bias)
+ self.before_proj_ref = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.before_proj_ref.weight)
+ nn.init.zeros_(self.before_proj_ref.bias)
+ self.after_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.after_proj.weight)
+ nn.init.zeros_(self.after_proj.bias)
+
+ def forward(self, c, c_ref, x, x_ref, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c_ref = self.before_proj_ref(c_ref) + x_ref
+ c, c_ref = super().forward(c, c_ref, **kwargs)
+ c_skip = self.after_proj(torch.cat([c_ref, c], dim=1))
+ return c, c_ref, c_skip
+
+
+class BaseWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=None
+ ):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
+ self.block_id = block_id
+
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
+ x = super().forward(x, **kwargs)
+ if self.block_id is not None:
+ x = x + hints[self.block_id] * context_scale
+ return x
+
+
+class VaceMMModel(MMModel):
+ @register_to_config
+ def __init__(self,
+ vace_layers=None,
+ vace_in_dim=None,
+ ref_in_dim=None,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ dim_ref=2048,
+ ffn_dim=8192,
+ ffn_dim_ref=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6):
+ model_type = "t2v" # TODO: hard-coded to "t2v" for both the preview and official versions.
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, dim_ref, ffn_dim, ffn_dim_ref, freq_dim, text_dim, out_dim,
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
+
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
+ self.ref_in_dim = ref_in_dim
+
+ assert 0 in self.vace_layers
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
+
+ # blocks
+ self.blocks = nn.ModuleList([
+ BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
+ self.cross_attn_norm, self.eps,
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
+ for i in range(self.num_layers)
+ ])
+
+ # vace blocks
+ self.vace_blocks = nn.ModuleList([
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.dim_ref, self.ffn_dim, self.ffn_dim_ref, self.num_heads, self.window_size, self.qk_norm,
+ self.cross_attn_norm, self.eps, block_id=i)
+ for i in self.vace_layers
+ ])
+
+ # vace patch embeddings
+ self.vace_patch_embedding = nn.Conv3d(
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+ self.ref_patch_embedding = nn.Conv3d(
+ self.ref_in_dim, self.dim_ref, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ def forward_vace(
+ self,
+ x,
+ x_ref,
+ vace_context,
+ ref_context,
+ seq_len,
+ seq_len_ref,
+ kwargs,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=False,
+ ):
+ # embeddings
+ c = self.vace_patch_embedding(vace_context)
+ c = c.flatten(2).transpose(1, 2)
+ # c = torch.cat([
+ # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ # dim=1) for u in c
+ # ])
+
+ c_ref = self.ref_patch_embedding(ref_context)
+ c_ref = c_ref.flatten(2).transpose(1, 2)
+ # c_ref = torch.cat([
+ # torch.cat([u, u.new_zeros(1, seq_len_ref - u.size(1), u.size(2))],
+ # dim=1) for u in c_ref
+ # ])
+
+ new_kwargs = dict(x=x, x_ref=x_ref)
+ new_kwargs.update(kwargs)
+ hints = []
+ for block in self.vace_blocks:
+ c, c_ref, c_skip = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ c,
+ c_ref,
+ **new_kwargs,
+ )
+ if c_skip is not None:
+ hints.append(c_skip)
+ return hints
+
+ def forward(
+ self,
+ x,
+ t,
+ vace_context,
+ ref_context,
+ context,
+ seq_len,
+ seq_len_ref,
+ vace_context_scale=1.0,
+ clip_fea=None,
+ y=None,
+ use_gradient_checkpointing=True,
+ use_gradient_checkpointing_offload=False,
+ ):
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+ self.freqs_ref = self.freqs_ref.to(device)
+
+ # embeddings
+
+ grid_sizes = torch.tensor([x.shape[2] // self.patch_size[0], x.shape[3] // self.patch_size[1], x.shape[4] // self.patch_size[2]]).long().unsqueeze(0).repeat(x.shape[0], 1)
+ seq_lens = torch.tensor(
+ (x.shape[2]//self.patch_size[0])
+ * (x.shape[3]//self.patch_size[1])
+ * (x.shape[4]//self.patch_size[2])
+ ).repeat(x.shape[0]).long()
+ x_ref = self.patch_embedding(x[:, :, :1, :, :])
+ x_ref = x_ref.flatten(2).transpose(1, 2)
+ x_vid = self.patch_embedding(x[:, :, 1:, :, :])
+ x_vid = x_vid.flatten(2).transpose(1, 2)
+ # print(x.dtype)
+ grid_sizes_ref = torch.tensor([ref_context.shape[2] // self.patch_size[0], ref_context.shape[3] // self.patch_size[1], ref_context.shape[4] // self.patch_size[2]]).long().unsqueeze(0).repeat(grid_sizes.shape[0], 1)
+
+ seq_lens_ref = torch.tensor(
+ (ref_context.shape[2]//self.patch_size[0])
+ * (ref_context.shape[3]//self.patch_size[1])
+ * (ref_context.shape[4]//self.patch_size[2])
+ ).repeat(x.shape[0]).long()
+ assert seq_lens.max() <= seq_len
+ # x = torch.cat([
+ # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+ # dim=1) for u in x
+ # ])
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ e_ref = self.time_embedding_ref(
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0_ref = self.time_projection_ref(e_ref).unflatten(1, (6, self.dim))
+
+ # text embedding
+ context_lens = None
+ context = self.text_embedding(context)
+
+ kwargs = dict(
+ e=e0,
+ e_ref=e0_ref,
+ seq_lens=seq_lens,
+ seq_lens_ref=seq_lens_ref,
+ grid_sizes=grid_sizes,
+ grid_sizes_ref=grid_sizes_ref,
+ freqs=self.freqs,
+ freqs_ref=self.freqs_ref,
+ context=context,
+ context_lens=context_lens,
+ )
+ # ========== VACE branch ==========
+ hints = self.forward_vace(
+ x_vid, x_ref, vace_context, ref_context, seq_len, seq_len_ref, kwargs,
+ use_gradient_checkpointing=False,
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+ )
+ kwargs['hints'] = hints
+ kwargs['context_scale'] = vace_context_scale
+ kwargs.pop("e_ref")
+ kwargs.pop("seq_lens_ref")
+ kwargs.pop("grid_sizes_ref")
+ kwargs.pop("freqs_ref")
+
+ # ========== main-path blocks ==========
+ x = torch.cat([x_ref, x_vid], dim=1)
+ for block in self.blocks:
+ x = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ x,
+ **kwargs,
+ )
+
+ x = self.head(x, e)
+ x = self.unpatchify(x, grid_sizes)
+ return [u.float() for u in x]
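
Token bookkeeping in `VaceMMModel.forward` above: the first latent frame becomes the reference stream (`x_ref`), the remaining frames the video stream (`x_vid`), the VACE hints are computed on both streams, and the main blocks then run on their concatenation. A small sketch of that split with illustrative shapes (the numbers are made up):

```python
import torch

# Made-up latent shapes; only the split/concat pattern mirrors VaceMMModel.forward above.
patch_size = (1, 2, 2)
B, frames, H, W = 1, 21, 60, 104   # 1 reference latent frame + 20 video latent frames
dim = 1536

ref_tokens = (1 // patch_size[0]) * (H // patch_size[1]) * (W // patch_size[2])
vid_tokens = ((frames - 1) // patch_size[0]) * (H // patch_size[1]) * (W // patch_size[2])

x_ref = torch.randn(B, ref_tokens, dim)  # reference stream after patch embedding
x_vid = torch.randn(B, vid_tokens, dim)  # video stream after patch embedding

# VACE hints are computed on (x_vid, x_ref); the main blocks then see ref tokens first.
x = torch.cat([x_ref, x_vid], dim=1)
print(ref_tokens, vid_tokens, x.shape)   # 1560 31200 torch.Size([1, 32760, 1536])
```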
diff --git a/vace/models/wan/modules/model_tr.py b/vace/models/wan/modules/model_tr.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbfc02a84c47e37fa339970db8e291ec10ec692
--- /dev/null
+++ b/vace/models/wan/modules/model_tr.py
@@ -0,0 +1,371 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from diffusers.configuration_utils import register_to_config
+from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
+
+
+def create_custom_forward(module):
+ def custom_forward(*inputs, **kwargs):
+ return module(*inputs, **kwargs)
+
+ return custom_forward
+
+
+def gradient_checkpoint_forward(
+ model,
+ use_gradient_checkpointing,
+ use_gradient_checkpointing_offload,
+ *args,
+ **kwargs,
+):
+ if use_gradient_checkpointing_offload:
+ with torch.autograd.graph.save_on_cpu():
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ elif use_gradient_checkpointing:
+ model_output = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(model),
+ *args,
+ **kwargs,
+ use_reentrant=False,
+ )
+ else:
+ model_output = model(*args, **kwargs)
+ return model_output
+
+
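+# Minimal illustration (commented out; `layer` and `x` below are placeholders, not part of
+# the model): checkpointing trades compute for memory by re-running the wrapped module's
+# forward pass during backward.
+#
+#     layer = nn.Linear(16, 16)
+#     x = torch.randn(2, 16, requires_grad=True)
+#     y = gradient_checkpoint_forward(layer, True, False, x)    # checkpointed forward
+#     y = gradient_checkpoint_forward(layer, False, False, x)   # plain forward
+
+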
+class VaceWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=0,
+ ):
+ super().__init__(
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size,
+ qk_norm,
+ cross_attn_norm,
+ eps,
+ )
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.before_proj.weight)
+ nn.init.zeros_(self.before_proj.bias)
+ self.after_proj = nn.Linear(self.dim, self.dim)
+ nn.init.zeros_(self.after_proj.weight)
+ nn.init.zeros_(self.after_proj.bias)
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c))
+ c = all_c.pop(-1)
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
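+        # collect the projected skip output for this layer alongside the running hidden
+        # state; forward_vace later recovers the skips via torch.unbind(c)[:-1]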
+ all_c += [c_skip, c]
+ c = torch.stack(all_c)
+ return c
+
+
+class BaseWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=None,
+ ):
+ super().__init__(
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size,
+ qk_norm,
+ cross_attn_norm,
+ eps,
+ )
+ self.block_id = block_id
+
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
+ x = super().forward(x, **kwargs)
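+        # inject this layer's VACE hint (if one exists for this block), scaled by context_scale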
+ if self.block_id is not None:
+ x = x + hints[self.block_id] * context_scale
+ return x
+
+
+class VaceWanModel(WanModel):
+ @register_to_config
+ def __init__(
+ self,
+ vace_layers=None,
+ vace_in_dim=None,
+ model_type="t2v",
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ activation_checkpointing=False,
+ eps=1e-6,
+ ):
+ super().__init__(
+ model_type,
+ patch_size,
+ text_len,
+ in_dim,
+ dim,
+ ffn_dim,
+ freq_dim,
+ text_dim,
+ out_dim,
+ num_heads,
+ num_layers,
+ window_size,
+ qk_norm,
+ cross_attn_norm,
+ eps,
+ )
+
+ self.vace_layers = (
+ [i for i in range(0, self.num_layers, 2)]
+ if vace_layers is None
+ else vace_layers
+ )
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
+
+ assert 0 in self.vace_layers
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
+
+ # blocks
+ self.blocks = nn.ModuleList(
+ [
+ BaseWanAttentionBlock(
+ "t2v_cross_attn",
+ self.dim,
+ self.ffn_dim,
+ self.num_heads,
+ self.window_size,
+ self.qk_norm,
+ self.cross_attn_norm,
+ self.eps,
+ block_id=(
+ self.vace_layers_mapping[i] if i in self.vace_layers else None
+ ),
+ )
+ for i in range(self.num_layers)
+ ]
+ )
+
+ # vace blocks
+ self.vace_blocks = nn.ModuleList(
+ [
+ VaceWanAttentionBlock(
+ "t2v_cross_attn",
+ self.dim,
+ self.ffn_dim,
+ self.num_heads,
+ self.window_size,
+ self.qk_norm,
+ self.cross_attn_norm,
+ self.eps,
+ block_id=i,
+ )
+ for i in self.vace_layers
+ ]
+ )
+
+ # vace patch embeddings
+ self.vace_patch_embedding = nn.Conv3d(
+ self.vace_in_dim,
+ self.dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ )
+
+ def forward_vace(self, x, vace_context, seq_len, kwargs):
+ # embeddings
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
+ c = [u.flatten(2).transpose(1, 2) for u in c]
+ c = torch.cat(
+ [
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+ for u in c
+ ]
+ )
+
+ # arguments
+ new_kwargs = dict(x=x)
+ new_kwargs.update(kwargs)
+
+ for block in self.vace_blocks:
+ c = block(c, **new_kwargs)
+ # for block in self.vace_blocks:
+ # c = gradient_checkpoint_forward(
+        #         block,
+ # True,
+ # False,
+ # c,
+ # **new_kwargs,
+ # )
+
+ hints = torch.unbind(c)[:-1]
+ return hints
+
+ def forward(
+ self,
+ x,
+ t,
+ vace_context,
+ context,
+ seq_len,
+ vace_context_scale=1.0,
+ clip_fea=None,
+ use_gradient_checkpointing=False,
+ y=None,
+ ):
+ r"""
+ Forward pass through the diffusion model
+
+ Args:
+ x (List[Tensor]):
+ List of input video tensors, each with shape [C_in, F, H, W]
+ t (Tensor):
+ Diffusion timesteps tensor of shape [B]
+ context (List[Tensor]):
+ List of text embeddings each with shape [L, C]
+ seq_len (`int`):
+ Maximum sequence length for positional encoding
+ clip_fea (Tensor, *optional*):
+ CLIP image features for image-to-video mode
+ y (List[Tensor], *optional*):
+ Conditional video inputs for image-to-video mode, same shape as x
+
+ Returns:
+ List[Tensor]:
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
+ """
+ # if self.model_type == 'i2v':
+ # assert clip_fea is not None and y is not None
+ # params
+ device = self.patch_embedding.weight.device
+ if self.freqs.device != device:
+ self.freqs = self.freqs.to(device)
+
+ # if y is not None:
+ # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+ # embeddings
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ grid_sizes = torch.stack(
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
+ )
+ x = [u.flatten(2).transpose(1, 2) for u in x]
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+ # print(seq_lens)
+ # print(seq_len)
+ assert seq_lens.max() <= seq_len
+ x = torch.cat(
+ [
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+ for u in x
+ ]
+ )
+
+ # time embeddings
+ with amp.autocast(dtype=torch.float32):
+ e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float())
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+ # context
+ context_lens = None
+ context = self.text_embedding(
+ torch.stack([
+ torch.cat(
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+ for u in context
+ ]))
+
+ # context = self.text_embedding(context)
+
+ # if clip_fea is not None:
+ # context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ # context = torch.concat([context_clip, context], dim=1)
+
+ # arguments
+ kwargs = dict(
+ e=e0,
+ seq_lens=seq_lens,
+ grid_sizes=grid_sizes,
+ freqs=self.freqs,
+ context=context,
+ context_lens=context_lens,
+ )
+
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
+ kwargs["hints"] = hints
+ kwargs["context_scale"] = vace_context_scale
+
+ # for block in self.blocks:
+
+ # if self.activation_checkpointing:
+ # ### Some bug here with deepspeed zero
+ # # print("use activation checkpointing")
+ # # x = x.requires_grad_()
+ # x = torch.utils.checkpoint.checkpoint(
+ # lambda inp: block(inp, **kwargs), x, use_reentrant=False
+ # )
+ # else:
+ # x = block(x, **kwargs)
+
+ for block in self.blocks:
+ x = gradient_checkpoint_forward(
+ block,
+ use_gradient_checkpointing,
+ False,
+ x,
+ **kwargs,
+ )
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ # return [u.float() for u in x]
+ # print(x.shape)
+ return x
diff --git a/vace/models/wan/wan_vace.py b/vace/models/wan/wan_vace.py
new file mode 100644
index 0000000000000000000000000000000000000000..d388c5073c28d13fbea987a2c3cdad8ae703dfe8
--- /dev/null
+++ b/vace/models/wan/wan_vace.py
@@ -0,0 +1,719 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import sys
+import gc
+import math
+import time
+import random
+import types
+import logging
+import traceback
+from contextlib import contextmanager
+from functools import partial
+
+from PIL import Image
+import torchvision.transforms.functional as TF
+import torch
+import torch.nn.functional as F
+import torch.cuda.amp as amp
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from tqdm import tqdm
+
+from wan.text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
+ get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
+from .modules.model import VaceWanModel
+from ..utils.preprocessor import VaceVideoProcessor
+
+
+class WanVace(WanT2V):
+ def __init__(
+ self,
+ config,
+ checkpoint_dir,
+ device_id=0,
+ rank=0,
+ t5_fsdp=False,
+ dit_fsdp=False,
+ use_usp=False,
+ t5_cpu=False,
+ ):
+ r"""
+ Initializes the Wan text-to-video generation model components.
+
+ Args:
+ config (EasyDict):
+ Object containing model parameters initialized from config.py
+ checkpoint_dir (`str`):
+ Path to directory containing model checkpoints
+ device_id (`int`, *optional*, defaults to 0):
+ Id of target GPU device
+ rank (`int`, *optional*, defaults to 0):
+ Process rank for distributed training
+ t5_fsdp (`bool`, *optional*, defaults to False):
+ Enable FSDP sharding for T5 model
+ dit_fsdp (`bool`, *optional*, defaults to False):
+ Enable FSDP sharding for DiT model
+ use_usp (`bool`, *optional*, defaults to False):
+ Enable distribution strategy of USP.
+ t5_cpu (`bool`, *optional*, defaults to False):
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
+ """
+ self.device = torch.device(f"cuda:{device_id}")
+ self.config = config
+ self.rank = rank
+ self.t5_cpu = t5_cpu
+
+ self.num_train_timesteps = config.num_train_timesteps
+ self.param_dtype = config.param_dtype
+
+ shard_fn = partial(shard_model, device_id=device_id)
+ self.text_encoder = T5EncoderModel(
+ text_len=config.text_len,
+ dtype=config.t5_dtype,
+ device=torch.device('cpu'),
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+ shard_fn=shard_fn if t5_fsdp else None)
+
+ self.vae_stride = config.vae_stride
+ self.patch_size = config.patch_size
+ self.vae = WanVAE(
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+ device=self.device)
+
+ logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
+ self.model = VaceWanModel.from_pretrained(checkpoint_dir)
+ self.model.eval().requires_grad_(False)
+
+ if use_usp:
+ from xfuser.core.distributed import \
+ get_sequence_parallel_world_size
+
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
+ usp_dit_forward,
+ usp_dit_forward_vace)
+ for block in self.model.blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+ for block in self.model.vace_blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
+ self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
+ self.sp_size = get_sequence_parallel_world_size()
+ else:
+ self.sp_size = 1
+
+ if dist.is_initialized():
+ dist.barrier()
+ if dit_fsdp:
+ self.model = shard_fn(self.model)
+ else:
+ self.model.to(self.device)
+
+ self.sample_neg_prompt = config.sample_neg_prompt
+
+ self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
+ min_area=480 * 832,
+ max_area=480 * 832,
+ min_fps=self.config.sample_fps,
+ max_fps=self.config.sample_fps,
+ zero_start=True,
+ seq_len=32760,
+ keep_last=True)
+
+ def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
+ vae = self.vae if vae is None else vae
+ if ref_images is None:
+ ref_images = [None] * len(frames)
+ else:
+ assert len(frames) == len(ref_images)
+
+ if masks is None:
+ latents = vae.encode(frames)
+ else:
+ masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
+ inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
+ reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
+ inactive = vae.encode(inactive)
+ reactive = vae.encode(reactive)
+ latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
+
+ cat_latents = []
+ for latent, refs in zip(latents, ref_images):
+ if refs is not None:
+ if masks is None:
+ ref_latent = vae.encode(refs)
+ else:
+ ref_latent = vae.encode(refs)
+ ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
+ assert all([x.shape[1] == 1 for x in ref_latent])
+ latent = torch.cat([*ref_latent, latent], dim=1)
+ cat_latents.append(latent)
+ return cat_latents
+
+ def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
+ vae_stride = self.vae_stride if vae_stride is None else vae_stride
+ if ref_images is None:
+ ref_images = [None] * len(masks)
+ else:
+ assert len(masks) == len(ref_images)
+
+ result_masks = []
+ for mask, refs in zip(masks, ref_images):
+ c, depth, height, width = mask.shape
+ new_depth = int((depth + 3) // vae_stride[0])
+ height = 2 * (int(height) // (vae_stride[1] * 2))
+ width = 2 * (int(width) // (vae_stride[2] * 2))
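+            # e.g. a (1, 81, 480, 832) mask with vae_stride=(4, 8, 8) gives new_depth=21,
+            # height=60, width=104; the view/reshape below packs each 8x8 spatial patch into
+            # 64 channels (this assumes vae_stride[1] == vae_stride[2])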
+
+ # reshape
+ mask = mask[0, :, :, :]
+ mask = mask.view(
+ depth, height, vae_stride[1], width, vae_stride[1]
+ ) # depth, height, 8, width, 8
+ mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
+ mask = mask.reshape(
+ vae_stride[1] * vae_stride[2], depth, height, width
+ ) # 8*8, depth, height, width
+
+ # interpolation
+ mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
+
+ if refs is not None:
+ length = len(refs)
+ mask_pad = torch.zeros_like(mask[:, :length, :, :])
+ mask = torch.cat((mask_pad, mask), dim=1)
+ result_masks.append(mask)
+ return result_masks
+
+ def vace_latent(self, z, m):
+ return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
+
+ def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
+ area = image_size[0] * image_size[1]
+ self.vid_proc.set_area(area)
+ if area == 720*1280:
+ self.vid_proc.set_seq_len(75600)
+ elif area == 480*832:
+ self.vid_proc.set_seq_len(32760)
+ else:
+ raise NotImplementedError(f'image_size {image_size} is not supported')
+
+ image_size = (image_size[1], image_size[0])
+ image_sizes = []
+ for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
+ if sub_src_mask is not None and sub_src_video is not None:
+ src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
+ src_video[i] = src_video[i].to(device)
+ src_mask[i] = src_mask[i].to(device)
+ src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
+ image_sizes.append(src_video[i].shape[2:])
+ elif sub_src_video is None:
+ src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
+ src_mask[i] = torch.ones_like(src_video[i], device=device)
+ image_sizes.append(image_size)
+ else:
+ src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
+ src_video[i] = src_video[i].to(device)
+ src_mask[i] = torch.ones_like(src_video[i], device=device)
+ image_sizes.append(src_video[i].shape[2:])
+
+ for i, ref_images in enumerate(src_ref_images):
+ if ref_images is not None:
+ image_size = image_sizes[i]
+ for j, ref_img in enumerate(ref_images):
+ if ref_img is not None:
+ ref_img = Image.open(ref_img).convert("RGB")
+ ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+ if ref_img.shape[-2:] != image_size:
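+                            # letterbox: scale to fit while keeping the aspect ratio, then
+                            # center on a white canvas of the target size (values in [-1, 1])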
+ canvas_height, canvas_width = image_size
+ ref_height, ref_width = ref_img.shape[-2:]
+ white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
+ scale = min(canvas_height / ref_height, canvas_width / ref_width)
+ new_height = int(ref_height * scale)
+ new_width = int(ref_width * scale)
+ resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
+ top = (canvas_height - new_height) // 2
+ left = (canvas_width - new_width) // 2
+ white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
+ ref_img = white_canvas
+ src_ref_images[i][j] = ref_img.to(device)
+ return src_video, src_mask, src_ref_images
+
+ def decode_latent(self, zs, ref_images=None, vae=None):
+ vae = self.vae if vae is None else vae
+ if ref_images is None:
+ ref_images = [None] * len(zs)
+ else:
+ assert len(zs) == len(ref_images)
+
+ trimed_zs = []
+ for z, refs in zip(zs, ref_images):
+ if refs is not None:
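+                # drop the leading reference-image latent frames so only video frames are decoded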
+ z = z[:, len(refs):, :, :]
+ trimed_zs.append(z)
+
+ return vae.decode(trimed_zs)
+
+
+
+ def generate(self,
+ input_prompt,
+ input_frames,
+ input_masks,
+ input_ref_images,
+ size=(1280, 720),
+ frame_num=81,
+ context_scale=1.0,
+ shift=5.0,
+ sample_solver='unipc',
+ sampling_steps=50,
+ guide_scale=5.0,
+ n_prompt="",
+ seed=-1,
+ offload_model=True):
+ r"""
+ Generates video frames from text prompt using diffusion process.
+
+ Args:
+ input_prompt (`str`):
+ Text prompt for content generation
+            size (tuple[`int`], *optional*, defaults to (1280,720)):
+ Controls video resolution, (width,height).
+ frame_num (`int`, *optional*, defaults to 81):
+ How many frames to sample from a video. The number should be 4n+1
+ shift (`float`, *optional*, defaults to 5.0):
+ Noise schedule shift parameter. Affects temporal dynamics
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
+ Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 50):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults to 5.0):
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
+ n_prompt (`str`, *optional*, defaults to ""):
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+ seed (`int`, *optional*, defaults to -1):
+ Random seed for noise generation. If -1, use random seed.
+ offload_model (`bool`, *optional*, defaults to True):
+ If True, offloads models to CPU during generation to save VRAM
+
+ Returns:
+ torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N, H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from size)
+                - W: Frame width (from size)
+ """
+ # preprocess
+ # F = frame_num
+ # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
+ # size[1] // self.vae_stride[1],
+ # size[0] // self.vae_stride[2])
+ #
+ # seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+ # (self.patch_size[1] * self.patch_size[2]) *
+ # target_shape[1] / self.sp_size) * self.sp_size
+
+ if n_prompt == "":
+ n_prompt = self.sample_neg_prompt
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=self.device)
+ seed_g.manual_seed(seed)
+
+ if not self.t5_cpu:
+ self.text_encoder.model.to(self.device)
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+ if offload_model:
+ self.text_encoder.model.cpu()
+ else:
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+ context = [t.to(self.device) for t in context]
+ context_null = [t.to(self.device) for t in context_null]
+
+ # vace context encode
+ z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
+ m0 = self.vace_encode_masks(input_masks, input_ref_images)
+ z = self.vace_latent(z0, m0)
+
+ target_shape = list(z0[0].shape)
+ target_shape[0] = int(target_shape[0] / 2)
+ noise = [
+ torch.randn(
+ target_shape[0],
+ target_shape[1],
+ target_shape[2],
+ target_shape[3],
+ dtype=torch.float32,
+ device=self.device,
+ generator=seed_g)
+ ]
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+ (self.patch_size[1] * self.patch_size[2]) *
+ target_shape[1] / self.sp_size) * self.sp_size
+
+ @contextmanager
+ def noop_no_sync():
+ yield
+
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+ # evaluation mode
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+
+ if sample_solver == 'unipc':
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sample_scheduler.set_timesteps(
+ sampling_steps, device=self.device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+ elif sample_solver == 'dpm++':
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+ timesteps, _ = retrieve_timesteps(
+ sample_scheduler,
+ device=self.device,
+ sigmas=sampling_sigmas)
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ # sample videos
+ latents = noise
+
+ arg_c = {'context': context, 'seq_len': seq_len}
+ arg_null = {'context': context_null, 'seq_len': seq_len}
+
+ for _, t in enumerate(tqdm(timesteps)):
+ latent_model_input = latents
+ timestep = [t]
+
+ timestep = torch.stack(timestep)
+
+ self.model.to(self.device)
+ noise_pred_cond = self.model(
+ latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
+ noise_pred_uncond = self.model(
+ latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
+
+ noise_pred = noise_pred_uncond + guide_scale * (
+ noise_pred_cond - noise_pred_uncond)
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0),
+ t,
+ latents[0].unsqueeze(0),
+ return_dict=False,
+ generator=seed_g)[0]
+ latents = [temp_x0.squeeze(0)]
+
+ x0 = latents
+ if offload_model:
+ self.model.cpu()
+ torch.cuda.empty_cache()
+ if self.rank == 0:
+ videos = self.decode_latent(x0, input_ref_images)
+
+ del noise, latents
+ del sample_scheduler
+ if offload_model:
+ gc.collect()
+ torch.cuda.synchronize()
+ if dist.is_initialized():
+ dist.barrier()
+
+ return videos[0] if self.rank == 0 else None
+
+
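+# Rough call sketch (commented out; the paths and the `config` object are placeholders for
+# whatever the surrounding scripts construct):
+#
+#     wan_vace = WanVace(config, checkpoint_dir="<Wan2.1-VACE checkpoint dir>")
+#     src_video, src_mask, src_refs = wan_vace.prepare_source(
+#         ["<video path>"], ["<mask path>"], [["<reference image path>"]],
+#         num_frames=81, image_size=(480, 832), device=wan_vace.device)
+#     video = wan_vace.generate("a prompt", src_video, src_mask, src_refs,
+#                               sampling_steps=50, seed=42)
+
+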
+class WanVaceMP(WanVace):
+ def __init__(
+ self,
+ config,
+ checkpoint_dir,
+ use_usp=False,
+ ulysses_size=None,
+ ring_size=None
+ ):
+ self.config = config
+ self.checkpoint_dir = checkpoint_dir
+ self.use_usp = use_usp
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12345'
+ os.environ['RANK'] = '0'
+ os.environ['WORLD_SIZE'] = '1'
+ self.in_q_list = None
+ self.out_q = None
+ self.inference_pids = None
+ self.ulysses_size = ulysses_size
+ self.ring_size = ring_size
+ self.dynamic_load()
+
+        # the parent process keeps data on the CPU; each spawned worker selects its own GPU
+        self.device = 'cpu'
+ self.vid_proc = VaceVideoProcessor(
+ downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
+ min_area=720 * 1280,
+ max_area=720 * 1280,
+ min_fps=config.sample_fps,
+ max_fps=config.sample_fps,
+ zero_start=True,
+ seq_len=75600,
+ keep_last=True)
+
+
+ def dynamic_load(self):
+ if hasattr(self, 'inference_pids') and self.inference_pids is not None:
+ return
+        gpu_infer = int(os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count())
+ pmi_rank = int(os.environ['RANK'])
+ pmi_world_size = int(os.environ['WORLD_SIZE'])
+ in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
+ out_q = torch.multiprocessing.Manager().Queue()
+ initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
+ context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
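+        # block until every spawned worker has signalled that its model is initialized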
+ all_initialized = False
+ while not all_initialized:
+ all_initialized = all(event.is_set() for event in initialized_events)
+ if not all_initialized:
+ time.sleep(0.1)
+ print('Inference model is initialized', flush=True)
+ self.in_q_list = in_q_list
+ self.out_q = out_q
+ self.inference_pids = context.pids()
+ self.initialized_events = initialized_events
+
+ def transfer_data_to_cuda(self, data, device):
+ if data is None:
+ return None
+ else:
+ if isinstance(data, torch.Tensor):
+ data = data.to(device)
+ elif isinstance(data, list):
+ data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
+ elif isinstance(data, dict):
+ data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
+ return data
+
+ def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
+ try:
+ world_size = pmi_world_size * gpu_infer
+ rank = pmi_rank * gpu_infer + gpu
+ print("world_size", world_size, "rank", rank, flush=True)
+
+ torch.cuda.set_device(gpu)
+ dist.init_process_group(
+ backend='nccl',
+ init_method='env://',
+ rank=rank,
+ world_size=world_size
+ )
+
+ from xfuser.core.distributed import (initialize_model_parallel,
+ init_distributed_environment)
+ init_distributed_environment(
+ rank=dist.get_rank(), world_size=dist.get_world_size())
+
+ initialize_model_parallel(
+ sequence_parallel_degree=dist.get_world_size(),
+ ring_degree=self.ring_size or 1,
+ ulysses_degree=self.ulysses_size or 1
+ )
+
+ num_train_timesteps = self.config.num_train_timesteps
+ param_dtype = self.config.param_dtype
+ shard_fn = partial(shard_model, device_id=gpu)
+ text_encoder = T5EncoderModel(
+ text_len=self.config.text_len,
+ dtype=self.config.t5_dtype,
+ device=torch.device('cpu'),
+ checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
+ tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
+                shard_fn=shard_fn)
+ text_encoder.model.to(gpu)
+ vae_stride = self.config.vae_stride
+ patch_size = self.config.patch_size
+ vae = WanVAE(
+ vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
+ device=gpu)
+ logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
+ model = VaceWanModel.from_pretrained(self.checkpoint_dir)
+ model.eval().requires_grad_(False)
+
+ if self.use_usp:
+ from xfuser.core.distributed import get_sequence_parallel_world_size
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
+ usp_dit_forward,
+ usp_dit_forward_vace)
+ for block in model.blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+ for block in model.vace_blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+ model.forward = types.MethodType(usp_dit_forward, model)
+ model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
+ sp_size = get_sequence_parallel_world_size()
+ else:
+ sp_size = 1
+
+ dist.barrier()
+ model = shard_fn(model)
+ sample_neg_prompt = self.config.sample_neg_prompt
+
+ torch.cuda.empty_cache()
+ event = initialized_events[gpu]
+ in_q = in_q_list[gpu]
+ event.set()
+
+ while True:
+ item = in_q.get()
+ input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
+ shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
+ input_frames = self.transfer_data_to_cuda(input_frames, gpu)
+ input_masks = self.transfer_data_to_cuda(input_masks, gpu)
+ input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
+
+ if n_prompt == "":
+ n_prompt = sample_neg_prompt
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=gpu)
+ seed_g.manual_seed(seed)
+
+ context = text_encoder([input_prompt], gpu)
+ context_null = text_encoder([n_prompt], gpu)
+
+ # vace context encode
+ z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
+ m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
+ z = self.vace_latent(z0, m0)
+
+ target_shape = list(z0[0].shape)
+ target_shape[0] = int(target_shape[0] / 2)
+ noise = [
+ torch.randn(
+ target_shape[0],
+ target_shape[1],
+ target_shape[2],
+ target_shape[3],
+ dtype=torch.float32,
+ device=gpu,
+ generator=seed_g)
+ ]
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+ (patch_size[1] * patch_size[2]) *
+ target_shape[1] / sp_size) * sp_size
+
+ @contextmanager
+ def noop_no_sync():
+ yield
+
+ no_sync = getattr(model, 'no_sync', noop_no_sync)
+
+ # evaluation mode
+ with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
+
+ if sample_solver == 'unipc':
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sample_scheduler.set_timesteps(
+ sampling_steps, device=gpu, shift=shift)
+ timesteps = sample_scheduler.timesteps
+ elif sample_solver == 'dpm++':
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+ timesteps, _ = retrieve_timesteps(
+ sample_scheduler,
+ device=gpu,
+ sigmas=sampling_sigmas)
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ # sample videos
+ latents = noise
+
+ arg_c = {'context': context, 'seq_len': seq_len}
+ arg_null = {'context': context_null, 'seq_len': seq_len}
+
+ for _, t in enumerate(tqdm(timesteps)):
+ latent_model_input = latents
+ timestep = [t]
+
+ timestep = torch.stack(timestep)
+
+ model.to(gpu)
+ noise_pred_cond = model(
+ latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
+ 0]
+ noise_pred_uncond = model(
+ latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
+ **arg_null)[0]
+
+ noise_pred = noise_pred_uncond + guide_scale * (
+ noise_pred_cond - noise_pred_uncond)
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0),
+ t,
+ latents[0].unsqueeze(0),
+ return_dict=False,
+ generator=seed_g)[0]
+ latents = [temp_x0.squeeze(0)]
+
+ torch.cuda.empty_cache()
+ x0 = latents
+ if rank == 0:
+ videos = self.decode_latent(x0, input_ref_images, vae=vae)
+
+ del noise, latents
+ del sample_scheduler
+ if offload_model:
+ gc.collect()
+ torch.cuda.synchronize()
+ if dist.is_initialized():
+ dist.barrier()
+
+ if rank == 0:
+ out_q.put(videos[0].cpu())
+
+ except Exception as e:
+ trace_info = traceback.format_exc()
+ print(trace_info, flush=True)
+ print(e, flush=True)
+
+
+
+ def generate(self,
+ input_prompt,
+ input_frames,
+ input_masks,
+ input_ref_images,
+ size=(1280, 720),
+ frame_num=81,
+ context_scale=1.0,
+ shift=5.0,
+ sample_solver='unipc',
+ sampling_steps=50,
+ guide_scale=5.0,
+ n_prompt="",
+ seed=-1,
+ offload_model=True):
+
+ input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
+ shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
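+        # broadcast the same request to every worker queue; only rank 0 puts the decoded video on out_q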
+ for in_q in self.in_q_list:
+ in_q.put(input_data)
+ value_output = self.out_q.get()
+
+ return value_output
diff --git a/vae.py b/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c42f6aa659099fdf28621edd64ea3b86ace4b0
--- /dev/null
+++ b/vae.py
@@ -0,0 +1,707 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+
+import torch
+# import torch.cuda.amp as amp
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+__all__ = [
+ 'WanVAE',
+]
+
+CACHE_T = 2
+
+
+class CausalConv3d(nn.Conv3d):
+ """
+ Causal 3d convolusion.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
+ self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
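+        # F.pad order for 5D input is (W_l, W_r, H_t, H_b, T_front, T_back): all temporal
+        # padding (2 * pad_t) is applied at the front, so an output frame never sees future frames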
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
+
+# class CausalConv3dNew(nn.Conv3d):
+# """
+# Causal 3d convolusion.
+# """
+
+# def __init__(self, *args, **kwargs):
+# super().__init__(*args, **kwargs)
+# self._padding = (self.padding[2], self.padding[2], self.padding[1],
+# self.padding[1], 2 * self.padding[0], 0)
+# self.padding = (0, 0, 0)
+
+# def forward(self, x, cache_x=None):
+# padding = list(self._padding)
+# if cache_x is not None and self._padding[4] > 0:
+# cache_x = cache_x.to(x.device)
+# x = torch.cat([cache_x, x], dim=2)
+# padding[4] -= cache_x.shape[2]
+# # x = F.pad(x, padding)
+
+# # return super().forward(x)
+
+# return F.conv3d(x,weight=self.weight,
+# bias=self.bias,
+# stride=self.stride,
+# dilation=self.dilation,
+# padding=padding)
+
+
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
+
+ def forward(self, x):
+ return F.normalize(
+ x, dim=(1 if self.channel_first else
+ -1)) * self.scale * self.gamma + self.bias
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
+ 'downsample3d')
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == 'upsample2d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ elif mode == 'upsample3d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
+ self.time_conv = CausalConv3d(
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == 'downsample2d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == 'downsample3d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == 'upsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = 'Rep'
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] != 'Rep':
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] == 'Rep':
+ cache_x = torch.cat([
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2)
+ if feat_cache[idx] == 'Rep':
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.resample(x)
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+
+ if self.mode == 'downsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -1:, :, :].clone()
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
+ # # cache last frame of last two chunk
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False), nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
+ if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ Causal self-attention with a single head.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.norm(x)
+ # compute query, key, value
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
+ -1).permute(0, 1, 3,
+ 2).contiguous().chunk(
+ 3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ )
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
+
+ # output
+ x = self.proj(x)
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = 'downsample3d' if temperal_downsample[
+ i] else 'downsample2d'
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout))
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2**(len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout))
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i == 1 or i == 2 or i == 3:
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, 3, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class WanVAE_(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_upsample, dropout)
+
+ def forward(self, x):
+ mu, log_var = self.encode(x)
+ z = self.reparameterize(mu, log_var)
+ x_recon = self.decode(z)
+ return x_recon, mu, log_var
+
+ def encode(self, x, scale):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+        ## split the encoder input x along time into chunks of 1, 4, 4, 4, ... frames
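+        # e.g. 81 input frames -> iter_ = 21 encoder calls -> 21 latent frames (temporal stride 4)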
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(
+ x[:, :, :1, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ if isinstance(scale[0], torch.Tensor):
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ mu = (mu - scale[0]) * scale[1]
+ self.clear_cache()
+ return mu
+
+ def decode(self, z, scale):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+ if isinstance(scale[0], torch.Tensor):
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
+ 1, self.z_dim, 1, 1, 1)
+ else:
+ z = z / scale[1] + scale[0]
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+ self.clear_cache()
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample(self, imgs, deterministic=False):
+ mu, log_var = self.encode(imgs)
+ if deterministic:
+ return mu
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
+ return mu + std * torch.randn_like(std)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ #cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+
+def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
+ """
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
+ """
+ # params
+ cfg = dict(
+ dim=96,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[False, True, True],
+ dropout=0.0)
+ cfg.update(**kwargs)
+
+ # init model
+ with torch.device('meta'):
+ model = WanVAE_(**cfg)
+
+ # load checkpoint
+ logging.info(f'loading {pretrained_path}')
+ model.load_state_dict(
+ torch.load(pretrained_path, map_location=device), assign=True)
+
+ return model
+
+
+class WanVAE(nn.Module):
+
+ def __init__(self,
+ z_dim=16,
+ vae_pth='cache/vae_step_411000.pth',
+ dtype=torch.float,
+ device="cuda"):
+ super().__init__()
+ self.dtype = dtype
+ self.device = device
+
+ mean = [
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+ ]
+ std = [
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+ ]
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
+ self.std = torch.tensor(std, dtype=dtype, device=device)
+ self.scale = [self.mean, 1.0 / self.std]
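+        # encode: z = (mu - mean) / std;  decode: z * std + mean (see WanVAE_.encode / decode)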
+
+ # init model
+ self.model = _video_vae(
+ pretrained_path=vae_pth,
+ z_dim=z_dim,
+ ).eval().requires_grad_(False).to(device)
+
+ def encode(self, videos):
+ """
+ videos: A list of videos each with shape [C, T, H, W].
+ """
+        with torch.autocast(device_type='cuda', dtype=self.dtype):
+            return self.model.encode(videos, self.scale)
+
+
+    def decode(self, zs):
+        with torch.autocast(device_type='cuda', dtype=self.dtype):
+            return self.model.decode(zs, self.scale).float().clamp_(-1, 1)
+
+
+ # def encode(self, videos):
+ # """
+ # videos: A list of videos each with shape [C, T, H, W].
+ # """
+ # with torch.autocast(device_type='cuda',dtype=self.dtype):
+ # return [
+ # self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
+ # for u in videos
+ # ]
+
+ # def decode(self, zs):
+ # with torch.autocast(device_type='cuda',dtype=self.dtype):
+ # return [
+ # self.model.decode(u.unsqueeze(0),
+ # self.scale).float().clamp_(-1, 1).squeeze(0)
+ # for u in zs
+ # ]
+
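+
+# Illustrative usage sketch (commented out; assumes the Wan2.1 VAE checkpoint has been
+# downloaded to a local path). Note that the active encode()/decode() above operate on a
+# batched [B, C, T, H, W] tensor:
+#
+#     vae = WanVAE(vae_pth="<path to Wan2.1_VAE.pth>", dtype=torch.float16, device="cuda")
+#     video = torch.randn(1, 3, 81, 480, 832, device="cuda", dtype=torch.float16)
+#     z = vae.encode(video)      # -> [1, 16, 21, 60, 104] normalized latent
+#     recon = vae.decode(z)      # -> [1, 3, 81, 480, 832], clamped to [-1, 1]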