diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4d1194d7dff69574e64c37896ed6629cb5a770 --- /dev/null +++ b/app.py @@ -0,0 +1,747 @@ +import os +import time +import random + +import gradio as gr +import cv2 +import numpy as np +from PIL import Image + +os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True) + +from huggingface_hub import snapshot_download + +def download_sam2(): + snapshot_download( + repo_id="facebook/sam2-hiera-large", + local_dir="./sam2/SAM2-Video-Predictor/checkpoints/", + ) + print("Download sam2 completed") + +def download_refacade(): + snapshot_download( + repo_id="fishze/Refacade", + local_dir="./models/", + ) + print("Download refacade completed") + + +# download_sam2() + +import torch +import torch.nn.functional as F +from decord import VideoReader, cpu +from moviepy.editor import ImageSequenceClip +from sam2.build_sam import build_sam2, build_sam2_video_predictor +from sam2.sam2_image_predictor import SAM2ImagePredictor +import spaces +from pipeline import RefacadePipeline +from vace.models.wan.modules.model_mm import VaceMMModel +from vace.models.wan.modules.model_tr import VaceWanModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from wan.text2video import FlowUniPCMultistepScheduler +from diffusers.utils import export_to_video, load_image, load_video +from vae import WanVAE + +COLOR_PALETTE = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (255, 128, 0), + (128, 0, 255), + (0, 128, 255), + (128, 255, 0), +] + +video_length = 201 +W = 1024 +H = W +device = "cuda" + +def get_pipe_image_and_video_predictor(): + vae = WanVAE( + vae_pth="./models/vae/Wan2.1_VAE.pth", + dtype=torch.float16, + ) + + pipe_device = "cuda" + + texture_remover = VaceWanModel.from_config( + "./models/texture_remover/texture_remover.json" + ) + ckpt = torch.load( + "./models/texture_remover/texture_remover.pth", + map_location="cpu", + ) + texture_remover.load_state_dict(ckpt) + texture_remover = texture_remover.to(dtype=torch.float16, device=pipe_device) + + model = VaceMMModel.from_config( + "./models/refacade/refacade.json" + ) + ckpt = torch.load( + "./models/refacade/refacade.pth", + map_location="cpu", + ) + model.load_state_dict(ckpt) + model = model.to(dtype=torch.float16, device=pipe_device) + + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=1, + ) + pipe = RefacadePipeline( + vae=vae, + transformer=model, + texture_remover=texture_remover, + scheduler=sample_scheduler, + ) + pipe.to(pipe_device) + + sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" + config = "sam2_hiera_l.yaml" + + video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device="cuda") + model_sam = build_sam2(config, sam2_checkpoint, device="cuda") + model_sam.image_size = 1024 + image_predictor = SAM2ImagePredictor(sam_model=model_sam) + + return pipe, image_predictor, video_predictor + + +def get_video_info(video_path, video_state): + video_state["input_points"] = [] + video_state["scaled_points"] = [] + video_state["input_labels"] = [] + video_state["frame_idx"] = 0 + + vr = VideoReader(video_path, ctx=cpu(0)) + first_frame = vr[0].asnumpy() + del vr + + if first_frame.shape[0] > first_frame.shape[1]: + W_ = W + H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1]) + else: + H_ = H + W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0]) + + first_frame = cv2.resize(first_frame, (W_, H_)) + video_state["origin_images"] = np.expand_dims(first_frame, axis=0) + video_state["inference_state"] = None + video_state["video_path"] = video_path + video_state["masks"] = None + video_state["painted_images"] = None + image = Image.fromarray(first_frame) + return image + + +def segment_frame(evt: gr.SelectData, label, video_state): + if video_state["origin_images"] is None: + return None + x, y = evt.index + new_point = [x, y] + label_value = 1 if label == "Positive" else 0 + + video_state["input_points"].append(new_point) + video_state["input_labels"].append(label_value) + height, width = video_state["origin_images"][0].shape[0:2] + scaled_points = [] + for pt in video_state["input_points"]: + sx = pt[0] / width + sy = pt[1] / height + scaled_points.append([sx, sy]) + + video_state["scaled_points"] = scaled_points + + image_predictor.set_image(video_state["origin_images"][0]) + mask, _, _ = image_predictor.predict( + point_coords=video_state["scaled_points"], + point_labels=video_state["input_labels"], + multimask_output=False, + normalize_coords=False, + ) + + mask = np.squeeze(mask) + mask = cv2.resize(mask, (width, height)) + mask = mask[:, :, None] + + color = ( + np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) + / 255.0 + ) + color = color[None, None, :] + org_image = video_state["origin_images"][0].astype(np.float32) / 255.0 + painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color + painted_image = np.uint8(np.clip(painted_image * 255, 0, 255)) + video_state["painted_images"] = np.expand_dims(painted_image, axis=0) + video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0) + + for i in range(len(video_state["input_points"])): + point = video_state["input_points"][i] + if video_state["input_labels"][i] == 0: + cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1) + else: + cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1) + + return Image.fromarray(painted_image) + + +def clear_clicks(video_state): + video_state["input_points"] = [] + video_state["input_labels"] = [] + video_state["scaled_points"] = [] + video_state["inference_state"] = None + video_state["masks"] = None + video_state["painted_images"] = None + return ( + Image.fromarray(video_state["origin_images"][0]) + if video_state["origin_images"] is not None + else None + ) + + +def set_ref_image(ref_img, ref_state): + if ref_img is None: + return None + + if isinstance(ref_img, Image.Image): + img_np = np.array(ref_img) + else: + img_np = ref_img + + ref_state["origin_image"] = img_np + ref_state["input_points"] = [] + ref_state["input_labels"] = [] + ref_state["scaled_points"] = [] + ref_state["mask"] = None + + return Image.fromarray(img_np) + + +def segment_ref_frame(evt: gr.SelectData, label, ref_state): + if ref_state["origin_image"] is None: + return None + + x, y = evt.index + new_point = [x, y] + label_value = 1 if label == "Positive" else 0 + + ref_state["input_points"].append(new_point) + ref_state["input_labels"].append(label_value) + + img = ref_state["origin_image"] + h, w = img.shape[:2] + + scaled_points = [] + for pt in ref_state["input_points"]: + sx = pt[0] / w + sy = pt[1] / h + scaled_points.append([sx, sy]) + ref_state["scaled_points"] = scaled_points + + image_predictor.set_image(img) + mask, _, _ = image_predictor.predict( + point_coords=scaled_points, + point_labels=ref_state["input_labels"], + multimask_output=False, + normalize_coords=False, + ) + + mask = np.squeeze(mask) + mask = cv2.resize(mask, (w, h)) + mask = mask[:, :, None] + ref_state["mask"] = mask[:, :, 0] + + color = ( + np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) + / 255.0 + ) + color = color[None, None, :] + org_image = img.astype(np.float32) / 255.0 + painted = (1 - mask * 0.5) * org_image + mask * 0.5 * color + painted = np.uint8(np.clip(painted * 255, 0, 255)) + + for i in range(len(ref_state["input_points"])): + point = ref_state["input_points"][i] + if ref_state["input_labels"][i] == 0: + cv2.circle(painted, point, radius=3, color=(0, 0, 255), thickness=-1) + else: + cv2.circle(painted, point, radius=3, color=(255, 0, 0), thickness=-1) + + return Image.fromarray(painted) + + +def clear_ref_clicks(ref_state): + ref_state["input_points"] = [] + ref_state["input_labels"] = [] + ref_state["scaled_points"] = [] + ref_state["mask"] = None + if ref_state["origin_image"] is None: + return None + return Image.fromarray(ref_state["origin_image"]) + + +@spaces.GPU(duration=40) +def track_video(n_frames, video_state): + input_points = video_state["input_points"] + input_labels = video_state["input_labels"] + frame_idx = video_state["frame_idx"] + obj_id = video_state["obj_id"] + scaled_points = video_state["scaled_points"] + + vr = VideoReader(video_state["video_path"], ctx=cpu(0)) + height, width = vr[0].shape[0:2] + images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))] + del vr + + if images[0].shape[0] > images[0].shape[1]: + W_ = W + H_ = int(W_ * images[0].shape[0] / images[0].shape[1]) + else: + H_ = H + W_ = int(H_ * images[0].shape[1] / images[0].shape[0]) + + images = [cv2.resize(img, (W_, H_)) for img in images] + video_state["origin_images"] = images + images = np.array(images) + + sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" + config = "sam2_hiera_l.yaml" + video_predictor_local = build_sam2_video_predictor( + config, sam2_checkpoint, device="cuda" + ) + + inference_state = video_predictor_local.init_state( + images=images / 255, device="cuda" + ) + + if len(torch.from_numpy(video_state["masks"][0]).shape) == 3: + mask = torch.from_numpy(video_state["masks"][0])[:, :, 0] + else: + mask = torch.from_numpy(video_state["masks"][0]) + + video_predictor_local.add_new_mask( + inference_state=inference_state, + frame_idx=0, + obj_id=obj_id, + mask=mask, + ) + + output_frames = [] + mask_frames = [] + color = ( + np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) + / 255.0 + ) + color = color[None, None, :] + for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video( + inference_state + ): + frame = images[out_frame_idx].astype(np.float32) / 255.0 + mask = np.zeros((H, W, 3), dtype=np.float32) + for i, logit in enumerate(out_mask_logits): + out_mask = logit.cpu().squeeze().detach().numpy() + out_mask = (out_mask[:, :, None] > 0).astype(np.float32) + mask += out_mask + mask = np.clip(mask, 0, 1) + mask = cv2.resize(mask, (W_, H_)) + mask_frames.append(mask) + painted = (1 - mask * 0.5) * frame + mask * 0.5 * color + painted = np.uint8(np.clip(painted * 255, 0, 255)) + output_frames.append(painted) + + video_state["masks"] = mask_frames + video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4" + clip = ImageSequenceClip(output_frames, fps=15) + clip.write_videofile( + video_file, codec="libx264", audio=False, verbose=False, logger=None + ) + print("Tracking done") + return video_file, video_state + + +@spaces.GPU(duration=50) +def inference_and_return_video( + dilate_radius, + num_inference_steps, + guidance_scale, + ref_patch_ratio, + fg_threshold, + seed, + video_state, + ref_state, +): + if video_state["origin_images"] is None or video_state["masks"] is None: + print("No video frames or video masks.") + return None, None, None + + if ref_state["origin_image"] is None or ref_state["mask"] is None: + print("Reference image or reference mask missing.") + return None, None, None + + images = video_state["origin_images"] + masks = video_state["masks"] + + video_frames = [] + mask_frames = [] + for img, msk in zip(images, masks): + if not isinstance(img, np.ndarray): + img = np.asarray(img) + img_pil = Image.fromarray(img.astype(np.uint8)) + + if isinstance(msk, np.ndarray): + if msk.ndim == 3: + m2 = msk[..., 0] + else: + m2 = msk + else: + m2 = np.asarray(msk) + + m2 = (m2 > 0.5).astype(np.uint8) * 255 + msk_pil = Image.fromarray(m2, mode="L") + + video_frames.append(img_pil) + mask_frames.append(msk_pil) + + num_frames = len(video_frames) + + h0, w0 = images[0].shape[:2] + if h0 > w0: + height = 832 + width = 480 + else: + height = 480 + width = 832 + + ref_img_np = ref_state["origin_image"] + ref_mask_np = ref_state["mask"] + + ref_img_pil = Image.fromarray(ref_img_np.astype(np.uint8)) + ref_mask_bin = (ref_mask_np > 0.5).astype(np.uint8) * 255 + ref_mask_pil = Image.fromarray(ref_mask_bin, mode="L") + + pipe.to("cuda") + with torch.no_grad(): + retex_frames, mesh_frames, ref_img_out = pipe( + video=video_frames, + mask=mask_frames, + reference_image=ref_img_pil, + reference_mask=ref_mask_pil, + conditioning_scale=1.0, + height=height, + width=width, + num_frames=num_frames, + dilate_radius=int(dilate_radius), + num_inference_steps=int(num_inference_steps), + guidance_scale=float(guidance_scale), + reference_patch_ratio=float(ref_patch_ratio), + fg_thresh=float(fg_threshold), + generator=torch.Generator(device="cuda").manual_seed(seed), + return_dict=False, + ) + + retex_frames_uint8 = (np.clip(retex_frames[0], 0.0, 1.0) * 255).astype(np.uint8) + + mesh_frames_uint8 = (np.clip(mesh_frames[0], 0.0, 1.0) * 255).astype(np.uint8) + + + retex_output_frames = [frame for frame in retex_frames_uint8] + mesh_output_frames = [frame for frame in mesh_frames_uint8] + + if ref_img_out.dtype != np.uint8: + ref_img_out = (np.clip(ref_img_out, 0.0, 1.0) * 255).astype(np.uint8) + + retex_video_file = f"/tmp/{time.time()}-{random.random()}-refacade_output.mp4" + retex_clip = ImageSequenceClip(retex_output_frames, fps=16) + retex_clip.write_videofile( + retex_video_file, codec="libx264", audio=False, verbose=False, logger=None + ) + + mesh_video_file = f"/tmp/{time.time()}-{random.random()}-mesh_output.mp4" + mesh_clip = ImageSequenceClip(mesh_output_frames, fps=16) + mesh_clip.write_videofile( + mesh_video_file, codec="libx264", audio=False, verbose=False, logger=None + ) + + ref_image_to_show = ref_img_out + + return retex_video_file, mesh_video_file, ref_image_to_show + + +text = """ +
+ Refaçade Video Retexture Demo +
+
+ Video mask from SAM2, Reference mask from SAM2 image clicks, RefacadePipeline for object retexture task. +
+""" + +pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor() + +with gr.Blocks() as demo: + video_state = gr.State( + { + "origin_images": None, + "inference_state": None, + "masks": None, + "painted_images": None, + "video_path": None, + "input_points": [], + "scaled_points": [], + "input_labels": [], + "frame_idx": 0, + "obj_id": 1, + } + ) + + ref_state = gr.State( + { + "origin_image": None, + "input_points": [], + "input_labels": [], + "scaled_points": [], + "mask": None, + } + ) + + gr.Markdown(f"
{text}
") + + with gr.Column(): + video_input = gr.Video(label="Upload Video", elem_id="my-video1") + get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn") + + gr.Examples( + examples=[ + ["./examples/1.mp4"], + ["./examples/2.mp4"], + ["./examples/3.mp4"], + ["./examples/4.mp4"], + ["./examples/5.mp4"], + ["./examples/6.mp4"], + ], + inputs=[video_input], + label="You can upload or choose a source video below to retexture.", + elem_id="my-btn2" + ) + + image_output = gr.Image( + label="First Frame Segmentation", + interactive=True, + elem_id="my-video", + ) + + demo.css = """ + #my-btn { + width: 60% !important; + margin: 0 auto; + } + #my-video1 { + width: 60% !important; + height: 35% !important; + margin: 0 auto; + } + #my-video { + width: 60% !important; + height: 35% !important; + margin: 0 auto; + } + #my-md { + margin: 0 auto; + } + #my-btn2 { + width: 60% !important; + margin: 0 auto; + } + #my-btn2 button { + width: 120px !important; + max-width: 120px !important; + min-width: 120px !important; + height: 70px !important; + max-height: 70px !important; + min-height: 70px !important; + margin: 8px !important; + border-radius: 8px !important; + overflow: hidden !important; + white-space: normal !important; + } + #my-btn3 { + width: 60% !important; + margin: 0 auto; + } + #ref_title { + text-align: center; + } + #ref-image { + width: 60% !important; + height: 35% !important; + margin: 0 auto; + } + #ref-mask { + width: 60% !important; + height: 35% !important; + margin: 0 auto; + } + #mesh-row { + width: 60% !important; + margin: 0 auto; + } + """ + + with gr.Row(elem_id="my-btn"): + point_prompt = gr.Radio( + ["Positive", "Negative"], label="Click Type", value="Positive" + ) + clear_btn = gr.Button("Clear All Clicks") + + with gr.Row(elem_id="my-btn"): + n_frames_slider = gr.Slider( + minimum=1, maximum=201, value=81, step=1, label="Tracking Frames (4N+1)" + ) + track_btn = gr.Button("Tracking") + video_output = gr.Video(label="Tracking Result", elem_id="my-video") + + gr.Markdown("Reference Image & Mask (SAM2 Points)", elem_id="ref_title") + + ref_image_input = gr.Image( + label="Upload Reference Image", elem_id="ref-image", interactive=True + ) + gr.Examples( + examples=[ + ["./examples/reference_image/1.png"], + ["./examples/reference_image/2.png"], + ["./examples/reference_image/3.png"], + ["./examples/reference_image/4.png"], + ["./examples/reference_image/5.png"], + ["./examples/reference_image/6.png"], + ["./examples/reference_image/7.png"], + ["./examples/reference_image/8.png"], + ["./examples/reference_image/9.png"], + ], + inputs=[ref_image_input], + label="You can upload or choose a reference image below to retexture.", + elem_id="my-btn3" + ) + ref_image_display = gr.Image( + label="Reference Mask Segmentation", + elem_id="ref-mask", + interactive=True, + ) + + with gr.Row(elem_id="my-btn"): + ref_point_prompt = gr.Radio( + ["Positive", "Negative"], label="Ref Click Type", value="Positive" + ) + ref_clear_btn = gr.Button("Clear Ref Clicks") + + with gr.Column(elem_id="my-btn"): + + dilate_radius_slider = gr.Slider( + minimum=1, + maximum=10, + value=3, + step=1, + label="Mask Dilation Radius", + ) + inference_steps_slider = gr.Slider( + minimum=1, + maximum=50, + value=20, + step=1, + label="Num Inference Steps", + ) + guidance_slider = gr.Slider( + minimum=1.0, + maximum=3.0, + value=1.5, + step=0.1, + label="Guidance Scale", + ) + ref_patch_slider = gr.Slider( + minimum=0.05, + maximum=1.0, + value=0.1, + step=0.05, + label="Reference Patch Ratio", + ) + fg_threshold_slider = gr.Slider( + minimum=0.7, + maximum=1.0, + value=0.8, + step=0.01, + label="Jigsaw Patches' Foreground Coverage Threshold", + ) + seed_slider = gr.Slider( + minimum=0, + maximum=2147483647, + value=42, + step=1, + label="Seed", + ) + + remove_btn = gr.Button("Retexture", elem_id="my-btn") + + with gr.Row(elem_id="mesh-row"): + mesh_video = gr.Video(label="Untextured Object") + ref_image_final = gr.Image( + label="Jigsawed Reference Image", + interactive=False, + ) + + remove_video = gr.Video(label="Retexture Results", elem_id="my-video") + + remove_btn.click( + inference_and_return_video, + inputs=[ + dilate_radius_slider, + inference_steps_slider, + guidance_slider, + ref_patch_slider, + fg_threshold_slider, + seed_slider, + video_state, + ref_state, + ], + outputs=[remove_video, mesh_video, ref_image_final], + ) + + get_info_btn.click( + get_video_info, + inputs=[video_input, video_state], + outputs=image_output, + ) + + image_output.select( + fn=segment_frame, + inputs=[point_prompt, video_state], + outputs=image_output, + ) + + clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output) + + track_btn.click( + track_video, + inputs=[n_frames_slider, video_state], + outputs=[video_output, video_state], + ) + + ref_image_input.change( + set_ref_image, + inputs=[ref_image_input, ref_state], + outputs=ref_image_display, + ) + ref_image_display.select( + fn=segment_ref_frame, + inputs=[ref_point_prompt, ref_state], + outputs=ref_image_display, + ) + ref_clear_btn.click( + clear_ref_clicks, inputs=ref_state, outputs=ref_image_display + ) + +demo.launch() + diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..dcac726ae8483973a267d307011265841652eec8 --- /dev/null +++ b/pipeline.py @@ -0,0 +1,718 @@ +from typing import Any, Callable, Dict, List, Optional, Union +import PIL.Image +import torch +import math +import random +import numpy as np +import torch.nn.functional as F +from typing import Tuple +from PIL import Image + +from vae import WanVAE +from vace.models.wan.modules.model_mm import VaceMMModel +from vace.models.wan.modules.model_tr import VaceWanModel + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.loaders import WanLoraLoaderMixin +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput +from dataclasses import dataclass + + +@dataclass +class RefacadePipelineOutput(BaseOutput): + frames: torch.Tensor + meshes: torch.Tensor + ref_img: torch.Tensor + + +logger = logging.get_logger(__name__) + + +@torch.no_grad() +def _pad_to_multiple(x: torch.Tensor, multiple: int, mode: str = "reflect"): + H, W = x.shape[-2], x.shape[-1] + pad_h = (multiple - H % multiple) % multiple + pad_w = (multiple - W % multiple) % multiple + pad = (0, pad_w, 0, pad_h) + if pad_h or pad_w: + x = F.pad(x, pad, mode=mode) + return x, pad + + +@torch.no_grad() +def _unpad(x: torch.Tensor, pad): + l, r, t, b = pad + H, W = x.shape[-2], x.shape[-1] + return x[..., t:H - b if b > 0 else H, l:W - r if r > 0 else W] + + +@torch.no_grad() +def _resize(x: torch.Tensor, size: tuple, is_mask: bool): + mode = "nearest" if is_mask else "bilinear" + if is_mask: + return F.interpolate(x, size=size, mode=mode) + else: + return F.interpolate(x, size=size, mode=mode, align_corners=False) + + +@torch.no_grad() +def _center_scale_foreground_to_canvas( + x_f: torch.Tensor, + m_f: torch.Tensor, + target_hw: tuple, + bg_value: float = 1.0, +): + C, H, W = x_f.shape + H2, W2 = target_hw + device = x_f.device + ys, xs = (m_f > 0.5).nonzero(as_tuple=True) + canvas = torch.full((C, H2, W2), bg_value, dtype=x_f.dtype, device=device) + mask_canvas = torch.zeros((1, H2, W2), dtype=x_f.dtype, device=device) + if ys.numel() == 0: + return canvas, mask_canvas + + y0, y1 = ys.min().item(), ys.max().item() + x0, x1 = xs.min().item(), xs.max().item() + crop_img = x_f[:, y0:y1 + 1, x0:x1 + 1] + crop_msk = m_f[y0:y1 + 1, x0:x1 + 1].unsqueeze(0) + hc, wc = crop_msk.shape[-2], crop_msk.shape[-1] + s = min(H2 / max(1, hc), W2 / max(1, wc)) + Ht = max(1, min(H2, int(math.floor(hc * s)))) + Wt = max(1, min(W2, int(math.floor(wc * s)))) + crop_img_up = _resize(crop_img.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0) + crop_msk_up = _resize(crop_msk.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0) + crop_msk_up = (crop_msk_up > 0.5).to(crop_msk_up.dtype) + + top = (H2 - Ht) // 2 + left = (W2 - Wt) // 2 + canvas[:, top:top + Ht, left:left + Wt] = crop_img_up + mask_canvas[:, top:top + Ht, left:left + Wt] = crop_msk_up + return canvas, mask_canvas + + +@torch.no_grad() +def _sample_patch_size_from_hw( + H: int, + W: int, + ratio: float = 0.2, + min_px: int = 16, + max_px: Optional[int] = None, +) -> int: + r = ratio + raw = r * min(H, W) + if max_px is None: + max_px = min(192, min(H, W)) + P = int(round(raw)) + P = max(min_px, min(P, max_px)) + P = int(P) + return P + + +@torch.no_grad() +def _masked_patch_pack_to_center_rectangle( + x_f: torch.Tensor, + m_f: torch.Tensor, + patch: int, + fg_thresh: float = 0.8, + bg_value: float = 1.0, + min_patches: int = 4, + flip_prob: float = 0.5, + use_morph_erode: bool = False, +): + + C, H, W = x_f.shape + device = x_f.device + P = int(patch) + + x_pad, pad = _pad_to_multiple(x_f, P, mode="reflect") + l, r, t, b = pad + H2, W2 = x_pad.shape[-2], x_pad.shape[-1] + m_pad = F.pad(m_f.unsqueeze(0).unsqueeze(0), (l, r, t, b), mode="constant", value=0.0).squeeze(0) + + cs_img, cs_msk = _center_scale_foreground_to_canvas(x_pad, m_pad.squeeze(0), (H2, W2), bg_value) + if (cs_msk > 0.5).sum() == 0: + out_img = _unpad(cs_img, pad).clamp_(-1, 1) + out_msk = _unpad(cs_msk, pad).clamp_(0, 1) + return out_img, out_msk, True + + m_eff = cs_msk + if use_morph_erode: + erode_px = int(max(1, min(6, round(P * 0.03)))) + m_eff = 1.0 - F.max_pool2d(1.0 - cs_msk, kernel_size=2 * erode_px + 1, stride=1, padding=erode_px) + + x_pad2, pad2 = _pad_to_multiple(cs_img, P, mode="reflect") + m_pad2 = F.pad(m_eff, pad2, mode="constant", value=0.0) + H3, W3 = x_pad2.shape[-2], x_pad2.shape[-1] + + m_pool = F.avg_pool2d(m_pad2, kernel_size=P, stride=P).view(-1) + + base_thr = float(fg_thresh) + thr_candidates = [base_thr, max(base_thr - 0.05, 0.75), max(base_thr - 0.10, 0.60)] + + x_unf = F.unfold(x_pad2.unsqueeze(0), kernel_size=P, stride=P) + N = x_unf.shape[-1] + + sel = None + for thr in thr_candidates: + idx = (m_pool >= (thr - 1e-6)).nonzero(as_tuple=False).squeeze(1) + if idx.numel() >= min_patches: + sel = idx + break + if sel is None: + img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1) + msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1) + return img_fallback, msk_fallback, True + + sel = sel.to(device=device, dtype=torch.long) + sel = sel[(sel >= 0) & (sel < N)] + if sel.numel() == 0: + img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1) + msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1) + return img_fallback, msk_fallback, True + + perm = torch.randperm(sel.numel(), device=device, dtype=torch.long) + sel = sel[perm] + chosen_x = x_unf[:, :, sel] + K = chosen_x.shape[-1] + if K == 0: + img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1) + msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1) + return img_fallback, msk_fallback, True + + if flip_prob > 0: + cx4 = chosen_x.view(1, C, P, P, K) + do_flip = (torch.rand(K, device=device) < flip_prob) + coin = (torch.rand(K, device=device) < 0.5) + flip_h = do_flip & coin + flip_v = do_flip & (~coin) + if flip_h.any(): + cx4[..., flip_h] = cx4[..., flip_h].flip(dims=[3]) + if flip_v.any(): + cx4[..., flip_v] = cx4[..., flip_v].flip(dims=[2]) + chosen_x = cx4.view(1, C * P * P, K) + + max_cols = max(1, W3 // P) + max_rows = max(1, H3 // P) + capacity = max_rows * max_cols + K_cap = min(K, capacity) + cols = int(max(1, min(int(math.floor(math.sqrt(K_cap))), max_cols))) + rows_full = min(max_rows, K_cap // cols) + K_used = rows_full * cols + if K_used == 0: + img_fallback = _unpad(_unpad(cs_img, pad2), pad).clamp_(-1, 1) + msk_fallback = _unpad(_unpad(cs_msk, pad2), pad).clamp_(0, 1) + return img_fallback, msk_fallback, True + + chosen_x = chosen_x[:, :, :K_used] + rect_unf = torch.full((1, C * P * P, rows_full * cols), bg_value, device=device, dtype=x_f.dtype) + rect_unf[:, :, :K_used] = chosen_x + rect = F.fold(rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0) + + ones_patch = torch.ones((1, 1 * P * P, K_used), device=device, dtype=x_f.dtype) + mask_rect_unf = torch.zeros((1, 1 * P * P, rows_full * cols), device=device, dtype=x_f.dtype) + mask_rect_unf[:, :, :K_used] = ones_patch + rect_mask = F.fold(mask_rect_unf, output_size=(rows_full * P, cols * P), kernel_size=P, stride=P).squeeze(0) + + Hr, Wr = rect.shape[-2], rect.shape[-1] + s = min(H3 / max(1, Hr), W3 / max(1, Wr)) + Ht = min(max(1, int(math.floor(Hr * s))), H3) + Wt = min(max(1, int(math.floor(Wr * s))), W3) + + rect_up = _resize(rect.unsqueeze(0), (Ht, Wt), is_mask=False).squeeze(0) + rect_mask_up = _resize(rect_mask.unsqueeze(0), (Ht, Wt), is_mask=True).squeeze(0) + + canvas_x = torch.full((C, H3, W3), bg_value, device=device, dtype=x_f.dtype) + canvas_m = torch.zeros((1, H3, W3), device=device, dtype=x_f.dtype) + top, left = (H3 - Ht) // 2, (W3 - Wt) // 2 + canvas_x[:, top:top + Ht, left:left + Wt] = rect_up + canvas_m[:, top:top + Ht, left:left + Wt] = rect_mask_up + + out_img = _unpad(_unpad(canvas_x, pad2), pad).clamp_(-1, 1) + out_msk = _unpad(_unpad(canvas_m, pad2), pad).clamp_(0, 1) + return out_img, out_msk, False + + +@torch.no_grad() +def _compose_centered_foreground(x_f: torch.Tensor, m_f3: torch.Tensor, target_hw: Tuple[int, int], bg_value: float = 1.0): + m_bin = (m_f3 > 0.5).float().mean(dim=0) + m_bin = (m_bin > 0.5).float() + return _center_scale_foreground_to_canvas(x_f, m_bin, target_hw, bg_value) + +class RefacadePipeline(DiffusionPipeline, WanLoraLoaderMixin): + + model_cpu_offload_seq = "texture_remover->transformer->vae" + + def __init__( + self, + vae, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: VaceMMModel = None, + texture_remover: VaceWanModel = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + texture_remover=texture_remover, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.empty_embedding = torch.load( + "/ms/AIGC/huangyouze/fish_pipeline/stage10_t5_encode/image/empty.pt", + map_location="cpu" + ) + self.negative_embedding = torch.load( + "/ms/AIGC/huangyouze/fish_pipeline/stage10_t5_encode/image/negative.pt", + map_location="cpu" + ) + + def vace_encode_masks(self, masks: torch.Tensor): + masks = masks[:, :1, :, :, :] + B, C, D, H, W = masks.shape + patch_h, patch_w = self.vae_scale_factor_spatial, self.vae_scale_factor_spatial + stride_t = self.vae_scale_factor_temporal + patch_count = patch_h * patch_w + new_D = (D + stride_t - 1) // stride_t + new_H = 2 * (H // (patch_h * 2)) + new_W = 2 * (W // (patch_w * 2)) + masks = masks[:, 0] + masks = masks.view(B, D, new_H, patch_h, new_W, patch_w) + masks = masks.permute(0, 3, 5, 1, 2, 4) + masks = masks.reshape(B, patch_count, D, new_H, new_W) + masks = F.interpolate( + masks, + size=(new_D, new_H, new_W), + mode="nearest-exact" + ) + return masks + + def preprocess_conditions( + self, + video: Optional[List[PipelineImageInput]] = None, + mask: Optional[List[PipelineImageInput]] = None, + reference_image: Optional[PIL.Image.Image] = None, + reference_mask: Optional[PIL.Image.Image] = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + reference_patch_ratio: float = 0.2, + fg_thresh: float = 0.9, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + + base = self.vae_scale_factor_spatial * 2 + video_height, video_width = self.video_processor.get_default_height_width(video[0]) + + if video_height * video_width > height * width: + scale_w = width / video_width + scale_h = height / video_height + video_height, video_width = int(video_height * scale_h), int(video_width * scale_w) + + if video_height % base != 0 or video_width % base != 0: + logger.warning( + f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. " + ) + video_height = (video_height // base) * base + video_width = (video_width // base) * base + + assert video_height * video_width <= height * width + + video = self.video_processor.preprocess_video(video, video_height, video_width) + image_size = (video_height, video_width) + + mask = self.video_processor.preprocess_video(mask, video_height, video_width) + mask = torch.clamp((mask + 1) / 2, min=0, max=1) + + video = video.to(dtype=dtype, device=device) + mask = mask.to(dtype=dtype, device=device) + + if reference_image is None: + raise ValueError("reference_image must be provided when using IMAGE_CONTROL mode.") + + if isinstance(reference_image, (list, tuple)): + ref_img_pil = reference_image[0] + else: + ref_img_pil = reference_image + + if reference_mask is not None and isinstance(reference_mask, (list, tuple)): + ref_mask_pil = reference_mask[0] + else: + ref_mask_pil = reference_mask + + ref_img_t = self.video_processor.preprocess(ref_img_pil, image_size[0], image_size[1]) + if ref_img_t.dim() == 4 and ref_img_t.shape[0] == 1: + ref_img_t = ref_img_t[0] + if ref_img_t.shape[0] == 1: + ref_img_t = ref_img_t.repeat(3, 1, 1) + ref_img_t = ref_img_t.to(dtype=dtype, device=device) + + H, W = image_size + if ref_mask_pil is not None: + if not isinstance(ref_mask_pil, Image.Image): + ref_mask_pil = Image.fromarray(np.array(ref_mask_pil)) + ref_mask_pil = ref_mask_pil.convert("L") + ref_mask_pil = ref_mask_pil.resize((W, H), Image.NEAREST) + mask_arr = np.array(ref_mask_pil) + m = torch.from_numpy(mask_arr).float() / 255.0 + m = (m > 0.5).float() + ref_msk3 = m.unsqueeze(0).repeat(3, 1, 1) + else: + ref_msk3 = torch.ones(3, H, W, dtype=dtype) + + ref_msk3 = ref_msk3.to(dtype=dtype, device=device) + + if math.isclose(reference_patch_ratio, 1.0, rel_tol=1e-6, abs_tol=1e-6): + cs_img, cs_m = _compose_centered_foreground( + x_f=ref_img_t, + m_f3=ref_msk3, + target_hw=image_size, + bg_value=1.0, + ) + ref_img_out = cs_img + ref_mask_out = cs_m + else: + patch = _sample_patch_size_from_hw( + H=image_size[0], + W=image_size[1], + ratio=reference_patch_ratio, + ) + + m_bin = (ref_msk3 > 0.5).float().mean(dim=0) + m_bin = (m_bin > 0.5).float() + reshuffled, reshuf_mask, used_fb = _masked_patch_pack_to_center_rectangle( + x_f=ref_img_t, + m_f=m_bin, + patch=patch, + fg_thresh=fg_thresh, + bg_value=1.0, + min_patches=4, + ) + + ref_img_out = reshuffled + ref_mask_out = reshuf_mask + + B = video.shape[0] + if batch_size is not None: + B = batch_size + + ref_image = ref_img_out.unsqueeze(0).unsqueeze(2).expand(B, -1, -1, -1, -1).contiguous() + ref_mask = ref_mask_out.unsqueeze(0).unsqueeze(2).expand(B, 3, -1, -1, -1).contiguous() + + ref_image = ref_image.to(dtype=dtype, device=device) + ref_mask = ref_mask.to(dtype=dtype, device=device) + + return video[:, :, :num_frames], mask[:, :, :num_frames], ref_image, ref_mask + + @torch.no_grad() + def texture_remove(self, foreground_latent): + sample_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1) + text_embedding = torch.zeros( + [256, 4096], + device=foreground_latent.device, + dtype=foreground_latent.dtype + ) + context = text_embedding.unsqueeze(0).expand( + foreground_latent.shape[0], -1, -1 + ).to(foreground_latent.device) + sample_scheduler.set_timesteps(3, device=foreground_latent.device) + timesteps = sample_scheduler.timesteps + noise = torch.randn_like( + foreground_latent, + dtype=foreground_latent.dtype, + device=foreground_latent.device + ) + seq_len = math.ceil( + noise.shape[2] * noise.shape[3] * noise.shape[4] / 4 + ) + latents = noise + arg_c = {"context": context, "seq_len": seq_len} + with torch.autocast(device_type="cuda", dtype=torch.float16): + for _, t in enumerate(timesteps): + timestep = torch.stack([t]).to(foreground_latent.device) + noise_pred_cond = self.texture_remover( + latents, + t=timestep, + vace_context=foreground_latent, + vace_context_scale=1, + **arg_c + )[0] + temp_x0 = sample_scheduler.step( + noise_pred_cond, t, latents, return_dict=False + )[0] + latents = temp_x0 + return latents + + def dilate_mask_hw(self, mask: torch.Tensor, radius: int = 3) -> torch.Tensor: + B, C, F_, H, W = mask.shape + k = 2 * radius + 1 + mask_2d = mask.permute(0, 2, 1, 3, 4).reshape(B * F_, C, H, W) + kernel = torch.ones( + (C, 1, k, k), + device=mask.device, + dtype=mask.dtype + ) + dilated_2d = F.conv2d( + mask_2d, + weight=kernel, + bias=None, + stride=1, + padding=radius, + groups=C + ) + dilated_2d = (dilated_2d > 0).to(mask.dtype) + dilated = dilated_2d.view(B, F_, C, H, W).permute(0, 2, 1, 3, 4) + return dilated + + def prepare_vace_latents( + self, + dilate_radius: int, + video: torch.Tensor, + mask: torch.Tensor, + reference_image: Optional[torch.Tensor] = None, + reference_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + device = device or self._execution_device + + vae_dtype = self.vae.dtype + video = video.to(dtype=vae_dtype) + mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype) + mask_clone = mask.clone() + mask = self.dilate_mask_hw(mask, dilate_radius) + inactive = video * (1 - mask) + reactive = video * mask_clone + reactive_latent = self.vae.encode(reactive) + mesh_latent = self.texture_remove(reactive_latent) + + inactive_latent = self.vae.encode(inactive) + ref_latent = self.vae.encode(reference_image) + neg_ref_latent = self.vae.encode(torch.ones_like(reference_image)) + + reference_mask = torch.where(reference_mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype) + mask = self.vace_encode_masks(mask) + ref_mask = self.vace_encode_masks(reference_mask) + + return inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask + + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @torch.no_grad() + def __call__( + self, + video: Optional[PipelineImageInput] = None, + mask: Optional[PipelineImageInput] = None, + reference_image: Optional[PipelineImageInput] = None, + reference_mask: Optional[PipelineImageInput] = None, + conditioning_scale: float = 1.0, + dilate_radius: int = 3, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 20, + guidance_scale: float = 1.5, + num_videos_per_prompt: Optional[int] = 1, + reference_patch_ratio: float = 0.2, + fg_thresh: float = 0.9, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + ): + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + + self._guidance_scale = guidance_scale + + device = self._execution_device + batch_size = 1 + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + video, mask, reference_image, reference_mask = self.preprocess_conditions( + video, + mask, + reference_image, + reference_mask, + batch_size, + height, + width, + num_frames, + reference_patch_ratio, + fg_thresh, + torch.float16, + device, + ) + + inactive_latent, mesh_latent, ref_latent, neg_ref_latent, mask, ref_mask = self.prepare_vace_latents(dilate_radius, video, mask, reference_image, reference_mask, device) + c = torch.cat([inactive_latent, mesh_latent, mask], dim=1) + c1 = torch.cat([ref_latent, ref_mask], dim=1) + c1_negative = torch.cat( + [neg_ref_latent, torch.zeros_like(ref_mask)], + dim=1 + ) + + num_channels_latents = 16 + noise = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float16, + device, + generator, + latents, + ) + + latents_cond = torch.cat([ref_latent, noise], dim=2) + latents_uncond = torch.cat([neg_ref_latent, noise], dim=2) + + seq_len = math.ceil( + latents_cond.shape[2] * + latents_cond.shape[3] * + latents_cond.shape[4] / 4 + ) + seq_len_ref = math.ceil( + ref_latent.shape[2] * + ref_latent.shape[3] * + ref_latent.shape[4] / 4 + ) + context = self.empty_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device) + context_neg = self.negative_embedding.unsqueeze(0).expand(batch_size, -1, -1).to(device) + arg_c = { + "context": context, + "seq_len": seq_len, + "seq_len_ref": seq_len_ref + } + arg_c_null = { + "context": context_neg, + "seq_len": seq_len, + "seq_len_ref": seq_len_ref + } + + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self._current_timestep = t + timestep = t.expand(batch_size) + + with torch.autocast(device_type="cuda", dtype=torch.float16): + noise_pred = self.transformer( + latents_cond, + t=timestep, + vace_context=c, + ref_context=c1, + vace_context_scale=conditioning_scale, + **arg_c, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + latents_uncond, + t=timestep, + vace_context=c, + ref_context=c1_negative, + vace_context_scale=0, + **arg_c_null, + )[0] + noise_pred = (noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)).unsqueeze(0) + temp_x0 = self.scheduler.step(noise_pred[:, :, 1:], + t, + latents_cond[:, :, 1:], + return_dict=False)[0] + latents_cond = torch.cat([ref_latent, temp_x0], dim=2) + latents_uncond = torch.cat([neg_ref_latent, temp_x0], dim=2) + progress_bar.update() + + + self._current_timestep = None + + if not output_type == "latent": + latents = temp_x0 + latents = latents.to(vae_dtype) + video = self.vae.decode(latents) + video = self.video_processor.postprocess_video(video, output_type=output_type) + mesh = self.vae.decode(mesh_latent.to(vae_dtype)) + mesh = self.video_processor.postprocess_video(mesh, output_type=output_type) + ref_img = reference_image.cpu().squeeze(0).squeeze(1).permute(1, 2, 0).numpy() + ref_img = ((ref_img+1)*255/2).astype(np.uint8) + else: + video = temp_x0 + mesh = mesh_latent + ref_img = ref_latent + + self.maybe_free_model_hooks() + + if not return_dict: + return (video, mesh, ref_img) + + return RefacadePipelineOutput(frames=video, meshes=mesh, ref_img=ref_img) diff --git a/sam2/SAM2-Video-Predictor/checkpoints/README.md b/sam2/SAM2-Video-Predictor/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..87274582326ed60c599f54b462b76da02085f813 --- /dev/null +++ b/sam2/SAM2-Video-Predictor/checkpoints/README.md @@ -0,0 +1,58 @@ +--- +license: apache-2.0 +pipeline_tag: mask-generation +library_name: sam2 +--- + +Repository for SAM 2: Segment Anything in Images and Videos, a foundation model towards solving promptable visual segmentation in images and videos from FAIR. See the [SAM 2 paper](https://arxiv.org/abs/2408.00714) for more information. + +The official code is publicly release in this [repo](https://github.com/facebookresearch/segment-anything-2/). + +## Usage + +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: + +```python +import torch +from sam2.sam2_video_predictor import SAM2VideoPredictor + +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, ): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... +``` + +Refer to the [demo notebooks](https://github.com/facebookresearch/segment-anything-2/tree/main/notebooks) for details. + +### Citation + +To cite the paper, model, or software, please use the below: +``` +@article{ravi2024sam2, + title={SAM 2: Segment Anything in Images and Videos}, + author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, + journal={arXiv preprint arXiv:2408.00714}, + url={https://arxiv.org/abs/2408.00714}, + year={2024} +} +``` \ No newline at end of file diff --git a/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml b/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b92b5b3a7392b627a74c46f29195ffdaab82d82 --- /dev/null +++ b/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False \ No newline at end of file diff --git a/sam2/__init__.py b/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d05adb1c40ab94d5eaa10da430ce9e9fd0e62bc9 --- /dev/null +++ b/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize + +from .build_sam import load_model + +initialize("configs", version_base="1.2") diff --git a/sam2/__pycache__/__init__.cpython-311.pyc b/sam2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4027f5036ad5a33fae0b2858d8863728e8b23b7e Binary files /dev/null and b/sam2/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/__pycache__/build_sam.cpython-311.pyc b/sam2/__pycache__/build_sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de0ffe3d050017efdd8e10d642681c2ae2eab275 Binary files /dev/null and b/sam2/__pycache__/build_sam.cpython-311.pyc differ diff --git a/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66ca05f50322d7753a239f062900daabe18a6f75 Binary files /dev/null and b/sam2/__pycache__/sam2_image_predictor.cpython-311.pyc differ diff --git a/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc b/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b84f2c5406a66130071c7480ce5aa94f88eccb Binary files /dev/null and b/sam2/__pycache__/sam2_video_predictor.cpython-311.pyc differ diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..37c12f48cc0e218fb1634431590bfaf93a1151e9 --- /dev/null +++ b/sam2/automatic_mask_generator.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor(points, device=self.predictor.device) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/sam2/build_sam.py b/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2d0922aa3d8f869afb20b27748d98304e175f0 --- /dev/null +++ b/sam2/build_sam.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + +from .utils.misc import VARIANTS, variant_to_config_mapping + + +def load_model( + variant: str, + ckpt_path=None, + device="cpu", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +) -> torch.nn.Module: + assert variant in VARIANTS, f"only accepted variants are {VARIANTS}" + + return build_sam2( + config_file=variant_to_config_mapping[variant], + ckpt_path=ckpt_path, + device=device, + mode=mode, + hydra_overrides_extra=hydra_overrides_extra, + apply_postprocessing=apply_postprocessing, + ) + + +def build_sam2( + config_file, + ckpt_path=None, + device="cpu", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cpu", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + # "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/sam2/configs/__init__.py b/sam2/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/configs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/configs/sam2_hiera_b+.yaml b/sam2/configs/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58f3eb81554018e873f8515ecb98e36d16ac29e4 --- /dev/null +++ b/sam2/configs/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2_hiera_l.yaml b/sam2/configs/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..918667f50c3e1ad2dcf77c0c14cb4dd114cfd080 --- /dev/null +++ b/sam2/configs/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2_hiera_s.yaml b/sam2/configs/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/sam2/configs/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2/configs/sam2_hiera_t.yaml b/sam2/configs/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a62c903aaa5f80828077c6e06a59626926570ed6 --- /dev/null +++ b/sam2/configs/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2b5255f11a31b524f22635c852e7deed04ae8b Binary files /dev/null and b/sam2/modeling/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a9d8fca6e3c9fc260526640253d3d338c7ef2c1 Binary files /dev/null and b/sam2/modeling/__pycache__/memory_attention.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aebed0f43913e0d424a78d3e46a3c05e8897f04 Binary files /dev/null and b/sam2/modeling/__pycache__/memory_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6508fe8bb4ddfa05b4c689d39ce2de6404514ea3 Binary files /dev/null and b/sam2/modeling/__pycache__/position_encoding.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..380ff5fdbfc7a3ac1e45ccf8d1b26b8d00ce4a8f Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_base.cpython-311.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efda8a2e63f373e68d87dba7f8733fa5c0f0d27 Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_utils.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f275aa7a5bb7b0a43e1801dd1e41567cca7e36 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9952a40a7ebd152ca403871df2fa5bfe0ab120e4 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/hieradet.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47478d488bd0b95c825e899e332481e5e8158d29 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe288f242b749edd87bba50c9d90b922b9040a1c Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/utils.cpython-311.pyc differ diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6d3b9fc82ac13ff1bde50312eb1cce517eb776 --- /dev/null +++ b/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) +from sam2.modeling.sam2_utils import MLP, DropPath + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5f92baf47dcab96385ff99899fd3e3a642c1cf9c --- /dev/null +++ b/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..00d7083200e2d3a1c378f5bec5239ebcd9036f18 --- /dev/null +++ b/sam2/modeling/memory_attention.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import Tensor, nn + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones +from sam2.modeling.sam.transformer import RoPEAttention + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b2df7aa754cd1a69e6231aad16ebaa10214411 --- /dev/null +++ b/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..4a308c30b69fd1a216fb1159b69261688a1792f6 --- /dev/null +++ b/sam2/modeling/position_encoding.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc b/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ed33b9cb928c69f169915fe938e85609f56614 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aab40a2eddf5f8d5fe0aa5f900c38672b1da744 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe20ea8edd26359ad5441798a7dfd34fb71a245a Binary files /dev/null and b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-311.pyc differ diff --git a/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc b/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a11939782e033b2acef84e48e2fa37588a0af65 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/transformer.cpython-311.pyc differ diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..29b488a57b05dcd85df714e73b37524a2e8897c8 --- /dev/null +++ b/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import MLP, LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e177a71a32b4ca9811f7b7535de50d988b3ad492 --- /dev/null +++ b/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..85baf152f418610eb0dcfa220f38dfc99ce53213 --- /dev/null +++ b/sam2/modeling/sam/transformer.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdp_backends + +warnings.simplefilter(action="ignore", category=FutureWarning) +# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + + #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + + #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..4326f448de8c9a652f0a841dad9620579b7976f4 --- /dev/null +++ b/sam2/modeling/sam2_base.py @@ -0,0 +1,828 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam2_utils import MLP, get_1d_sine_pe, select_closest_cond_frames +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].to(self.device) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(self.device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9705963efc57d74b7d1bff31692d7d293a46ad --- /dev/null +++ b/sam2/modeling/sam2_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..6444b0f225d7a99774e51e587cd9fa7d9f71ec39 --- /dev/null +++ b/sam2/sam2_image_predictor.py @@ -0,0 +1,445 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to + the maximum area of fill_hole_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + print("feat_size", self._bb_feat_sizes[::-1]) + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tupele of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b911105a395fd759a0bdef5ef4574f0e7bf301 --- /dev/null +++ b/sam2/sam2_video_predictor.py @@ -0,0 +1,900 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + + @torch.inference_mode() + def init_state( + self, + video_path=None, + images=None, + device="cpu", + async_loading_frames=False, + ): + """Initialize a inference state.""" + if images is not None: + images, video_height, video_width = load_video_frames( + video_path=None, + images=images, + image_size=self.image_size, + async_loading_frames=async_loading_frames, + device=device, + ) + else: + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + async_loading_frames=async_loading_frames, + device=device, + ) + inference_state = dict() + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = device + inference_state["storage_device"] = device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + clear_old_points=True, + normalize_coords=True, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"]) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + image = ( + inference_state["images"][frame_idx] + .to(inference_state["device"]) + .float() + .unsqueeze(0) + ) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d9f56d9ac50c1c34d742f757c488b06b0c2b4e --- /dev/null +++ b/sam2/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .download import download_weights diff --git a/sam2/utils/__pycache__/__init__.cpython-311.pyc b/sam2/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f420b43150fff84bc29294a177347c686886ae8e Binary files /dev/null and b/sam2/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/sam2/utils/__pycache__/download.cpython-311.pyc b/sam2/utils/__pycache__/download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c347a06cb6c1b0f14fded33a11ffdef72738f50 Binary files /dev/null and b/sam2/utils/__pycache__/download.cpython-311.pyc differ diff --git a/sam2/utils/__pycache__/misc.cpython-311.pyc b/sam2/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65e5d90f57b978985524fa52926f6151f4550ff4 Binary files /dev/null and b/sam2/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/sam2/utils/__pycache__/transforms.cpython-311.pyc b/sam2/utils/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ffee0d81bb9a8bcc616c68882c205364d2c5b0c Binary files /dev/null and b/sam2/utils/__pycache__/transforms.cpython-311.pyc differ diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e --- /dev/null +++ b/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/sam2/utils/download.py b/sam2/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..577a0b16e6bba6a70dd94cc9360bc21ce5fe4586 --- /dev/null +++ b/sam2/utils/download.py @@ -0,0 +1,36 @@ +import os +import subprocess +from typing import List + +import pytest + + +@pytest.fixture +def download_weights(output_directory: str = "artifacts") -> None: + base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" + file_names: List[str] = [ + "sam2_hiera_tiny.pt", + "sam2_hiera_small.pt", + "sam2_hiera_base_plus.pt", + "sam2_hiera_large.pt", + ] + + if not os.path.exists(output_directory): + os.makedirs(output_directory) + + for file_name in file_names: + file_path = os.path.join(output_directory, file_name) + if not os.path.exists(file_path): + url = f"{base_url}{file_name}" + command = ["wget", url, "-P", output_directory] + try: + result = subprocess.run( + command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + print(f"Download of {file_name} completed successfully.") + print(result.stdout.decode()) + except subprocess.CalledProcessError as e: + print(f"An error occurred during the download of {file_name}.") + print(e.stderr.decode()) + else: + print(f"{file_name} already exists. Skipping download.") diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..b860df46c33b8377a0125a581ec17f61d424766f --- /dev/null +++ b/sam2/utils/misc.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread +from typing import Dict, List, Union + +import numpy as np +import torch +from PIL import Image +from torch.nn.attention import SDPBackend +from einops import rearrange +from tqdm import tqdm +import torch.nn.functional as F + +VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"] + +variant_to_config_mapping: Dict[str, str] = { + "tiny": "sam2_hiera_t.yaml", + "small": "sam2_hiera_s.yaml", + "base_plus": "sam2_hiera_b+.yaml", + "large": "sam2_hiera_l.yaml", +} + + +def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]: + backends = [] + if torch.cuda.is_available(): + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + + if torch.cuda.get_device_properties(0).major < 7: + backends.append(SDPBackend.EFFICIENT_ATTENTION) + + if use_flash_attn: + backends.append(SDPBackend.FLASH_ATTENTION) + + if pytorch_version < (2, 2) or not use_flash_attn: + backends.append(SDPBackend.MATH) + + if ( + SDPBackend.EFFICIENT_ATTENTION in backends and dropout_p > 0.0 + ) and SDPBackend.MATH not in backends: + backends.append(SDPBackend.MATH) + + else: + backends.extend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) + + return backends + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] boxes, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__(self, img_paths, image_size, img_mean, img_std, device): + self.img_paths = img_paths + self.image_size = image_size + self.img_mean = img_mean + self.img_std = img_std + self.device = device + # items in `self._images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + img = img.to(self.device) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + images=None, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + device="cpu", +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if images is not None: + images = torch.from_numpy(images).float() + images = rearrange(images, "f h w c -> f c h w") + images = F.interpolate(images, (image_size, image_size), mode="bilinear") + video_height, video_width = images.shape[2:] + else: + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + images = images.to(device) + img_mean = img_mean.to(device) + img_std = img_std.to(device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d05cd3e5ebacadca10037c138cbeebfa6b89adab --- /dev/null +++ b/sam2/utils/transforms.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + labels, areas = get_connected_components(mask_flat <= self.mask_threshold) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components(mask_flat > self.mask_threshold) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/sam2/utils/visualization.py b/sam2/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5d5460f9466098fc11c54c735d7545e5290184 --- /dev/null +++ b/sam2/utils/visualization.py @@ -0,0 +1,68 @@ +from typing import Optional + +import numpy as np +from PIL import Image + + +def show_masks( + image: np.ndarray, + masks: np.ndarray, + scores: Optional[np.ndarray], + alpha: Optional[float] = 0.5, + display_image: Optional[bool] = False, + only_best: Optional[bool] = True, + autogenerated_mask: Optional[bool] = False, +) -> Image.Image: + if scores is not None: + # sort masks by their scores + sorted_ind = np.argsort(scores)[::-1] + masks = masks[sorted_ind] + + if autogenerated_mask: + masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) + else: + # get mask dimensions + h, w = masks.shape[-2:] + + if display_image: + output_image = Image.fromarray(image) + else: + # create a new blank image to superimpose masks + if autogenerated_mask: + output_image = Image.new( + mode="RGBA", + size=( + masks[0]["segmentation"].shape[1], + masks[0]["segmentation"].shape[0], + ), + color=(0, 0, 0), + ) + else: + output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0)) + + for i, mask in enumerate(masks): + if not autogenerated_mask: + if mask.ndim > 2: # type: ignore + mask = mask.squeeze() # type: ignore + else: + mask = mask["segmentation"] + # Generate a random color with specified alpha value + color = np.concatenate( + (np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0 + ) + + # Create an RGBA image for the mask + mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("L") + mask_colored = Image.new("RGBA", mask_image.size, tuple(color)) + mask_image = Image.composite( + mask_colored, Image.new("RGBA", mask_image.size), mask_image + ) + + # Overlay mask on the output image + output_image = Image.alpha_composite(output_image, mask_image) + + # Exit if specified to only display the best mask + if only_best: + break + + return output_image diff --git a/vace/__init__.py b/vace/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ba686bc2105181216101c39ae4412c247aadbe --- /dev/null +++ b/vace/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from . import models \ No newline at end of file diff --git a/vace/__pycache__/__init__.cpython-311.pyc b/vace/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2199364bd7b32768f4ce9c31bf173ad7bfe16946 Binary files /dev/null and b/vace/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/__init__.py b/vace/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d99a0a60c32e0c5b194c8bc62125f2ef4ab7e319 --- /dev/null +++ b/vace/models/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +from . import utils + +try: + from . import wan +except ImportError as e: + print("Warning: failed to importing 'wan'. Please install its dependencies with:") + print("pip install wan@git+https://github.com/Wan-Video/Wan2.1") diff --git a/vace/models/__pycache__/__init__.cpython-311.pyc b/vace/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d91315b8bdc5373327ab3018ffab6645e2d8fb38 Binary files /dev/null and b/vace/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/utils/__init__.py b/vace/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d95c410fcc2c3e35b127c99e988a00ee1ad85a19 --- /dev/null +++ b/vace/models/utils/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from .preprocessor import VaceVideoProcessor \ No newline at end of file diff --git a/vace/models/utils/__pycache__/__init__.cpython-311.pyc b/vace/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd77e5c03f90365cee502e2a0b0ae65bb378676e Binary files /dev/null and b/vace/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc b/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d80e69e8b5c58029ae2e81663dd4a0bde8e81fe3 Binary files /dev/null and b/vace/models/utils/__pycache__/preprocessor.cpython-311.pyc differ diff --git a/vace/models/utils/preprocessor.py b/vace/models/utils/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..a0788111a7b79fda3070a2ab8372956c0726af26 --- /dev/null +++ b/vace/models/utils/preprocessor.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF + + +class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): + self.downsample = downsample + self.seq_len = seq_len + + def _pillow_convert(self, image, cvt_type='RGB'): + if image.mode != cvt_type: + if image.mode == 'P': + image = image.convert(f'{cvt_type}A') + if image.mode == f'{cvt_type}A': + bg = Image.new(cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) + bg.paste(image, (0, 0), mask=image) + image = bg + else: + image = image.convert(cvt_type) + return image + + def _load_image(self, img_path): + if img_path is None or img_path == '': + return None + img = Image.open(img_path) + img = self._pillow_convert(img) + return img + + def _resize_crop(self, img, oh, ow, normalize=True): + """ + Resize, center crop, convert to tensor, and normalize. + """ + # resize and crop + iw, ih = img.size + if iw != ow or ih != oh: + # resize + scale = max(ow / iw, oh / ih) + img = img.resize( + (round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS + ) + assert img.width >= ow and img.height >= oh + + # center crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + # normalize + if normalize: + img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) + return img + + def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): + return self._resize_crop(img, oh, ow, normalize) + + def load_image(self, data_key, **kwargs): + return self.load_image_batch(data_key, **kwargs) + + def load_image_pair(self, data_key, data_key2, **kwargs): + return self.load_image_batch(data_key, data_key2, **kwargs) + + def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + seq_len = self.seq_len if seq_len is None else seq_len + imgs = [] + for data_key in data_key_batch: + img = self._load_image(data_key) + imgs.append(img) + w, h = imgs[0].size + dh, dw = self.downsample[1:] + + # compute output size + scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) + oh = int(h * scale) // dh * dh + ow = int(w * scale) // dw * dw + assert (oh // dh) * (ow // dw) <= seq_len + imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] + return *imgs, (oh, ow) + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + def set_area(self, area): + self.min_area = area + self.max_area = area + + def set_seq_len(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate( + video, + size=(round(scale * ih), round(scale * iw)), + mode='bicubic', + antialias=True + ) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1:y1 + oh, x1:x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.) + return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1] + ), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (len(frame_timestamps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = duration + target_fps = of / target_duration + timestamps = np.linspace(0., target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] <= frame_timestamps[None, :, 1] + ), axis=1).tolist() + # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + + def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): + if self.keep_last: + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + else: + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + + def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs): + rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) + # read video + import decord + decord.bridge.set_bridge('torch') + readers = [] + for data_k in data_key_batch: + reader = decord.VideoReader(data_k) + readers.append(reader) + + fps = readers[0].get_avg_fps() + length = min([len(r) for r in readers]) + frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = np.array(frame_timestamps, dtype=np.float32) + h, w = readers[0].next().shape[:2] + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + + # preprocess video + videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [self._video_preprocess(video, oh, ow) for video in videos] + return *videos, frame_ids, (oh, ow), fps + # return videos if len(videos) > 1 else videos[0] + + +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_video is None and sub_src_mask is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + for j, ref_img in enumerate(ref_images): + if ref_img is not None and ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + src_ref_images[i][j] = white_canvas + return src_video, src_mask, src_ref_images diff --git a/vace/models/wan/__init__.py b/vace/models/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9f319186abd76a97bd448b6ceb57564e117c80 --- /dev/null +++ b/vace/models/wan/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from . import modules +from .wan_vace import WanVace diff --git a/vace/models/wan/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36e5b7bd26c4a6d3fee9b38652f21300293b1155 Binary files /dev/null and b/vace/models/wan/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc b/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b528699128b5584f2e2351a7801e1b61521b54 Binary files /dev/null and b/vace/models/wan/__pycache__/wan_vace.cpython-311.pyc differ diff --git a/vace/models/wan/configs/__init__.py b/vace/models/wan/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46ebd0dd2a619b230a7f7fd6e36326be020c7882 --- /dev/null +++ b/vace/models/wan/configs/__init__.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +WAN_CONFIGS = { + 'vace-1.3B': t2v_1_3B, + 'vace-14B': t2v_14B, +} + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), + '720p': (1280, 720), + '480p': (480, 832) +} + +MAX_AREA_CONFIGS = { + '720*1280': 720 * 1280, + '1280*720': 1280 * 720, + '480*832': 480 * 832, + '832*480': 832 * 480, + '720p': 1280 * 720, + '480p': 480 * 832 +} + +SUPPORTED_SIZES = { + 'vace-1.3B': ('480*832', '832*480', '480p'), + 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480', '480p', '720p') +} diff --git a/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b7140198e5b3ae3f3db9401912d2d2742214e7f Binary files /dev/null and b/vace/models/wan/configs/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc b/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c21b0fe4f8c9717b4ab76019edbb79049d21cb4 Binary files /dev/null and b/vace/models/wan/configs/__pycache__/shared_config.cpython-311.pyc differ diff --git a/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc b/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93579703336219d5add35a27fb558aa28573c73 Binary files /dev/null and b/vace/models/wan/configs/__pycache__/wan_t2v_14B.cpython-311.pyc differ diff --git a/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc b/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cb7d6a10dd3636109d5705516e6973ebc64d77c Binary files /dev/null and b/vace/models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-311.pyc differ diff --git a/vace/models/wan/configs/shared_config.py b/vace/models/wan/configs/shared_config.py new file mode 100644 index 0000000000000000000000000000000000000000..74e5cef2540f4ffb00dab185c93bd42b30b89989 --- /dev/null +++ b/vace/models/wan/configs/shared_config.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = torch.float16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' diff --git a/vace/models/wan/configs/wan_t2v_14B.py b/vace/models/wan/configs/wan_t2v_14B.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0ee69dea796bfd6eccdedf4ec04835086227a6 --- /dev/null +++ b/vace/models/wan/configs/wan_t2v_14B.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') +t2v_14B.update(wan_shared_cfg) + +# t5 +t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/vace/models/wan/configs/wan_t2v_1_3B.py b/vace/models/wan/configs/wan_t2v_1_3B.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd1464ec010e7bf2570e4375e2814a0943a189a --- /dev/null +++ b/vace/models/wan/configs/wan_t2v_1_3B.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') +t2v_1_3B.update(wan_shared_cfg) + +# t5 +t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/vace/models/wan/distributed/__init__.py b/vace/models/wan/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13a6b25acc146323b5d4769bf7ed6abc3b5d7d68 --- /dev/null +++ b/vace/models/wan/distributed/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from .xdit_context_parallel import pad_freqs, rope_apply, usp_dit_forward_vace, usp_dit_forward, usp_attn_forward \ No newline at end of file diff --git a/vace/models/wan/distributed/xdit_context_parallel.py b/vace/models/wan/distributed/xdit_context_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..aacf47fa041966d857a3c68bed48e9363375802d --- /dev/null +++ b/vace/models/wan/distributed/xdit_context_parallel.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.cuda.amp as amp +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +from ..modules.model import sinusoidal_embedding_1d + + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. + grid_sizes: [B, 3]. + freqs: [M, C // 2]. + """ + s, n, c = x.size(1), x.size(2), x.size(3) // 2 + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( + s, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs_i = pad_freqs(freqs_i, s * sp_size) + s_per_rank = s + freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * + s_per_rank), :, :] + x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +def usp_dit_forward_vace( + self, + x, + vace_context, + seq_len, + kwargs +): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + # Context Parallel + c = torch.chunk( + c, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + hints = torch.unbind(c)[:-1] + return hints + + +def usp_dit_forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + y=None, +): + """ + x: A list of videos each with shape [C, T, H, W]. + t: [B]. + context: A list of text embeddings each with shape [L, C]. + """ + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # if clip_fea is not None: + # context_clip = self.img_emb(clip_fea) # bs x 257 x dim + # context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + hints = self.forward_vace(x, vace_context, seq_len, kwargs) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + +def usp_attn_forward(self, + x, + seq_lens, + grid_sizes, + freqs, + dtype=torch.bfloat16): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + + # TODO: We should use unpaded q,k,v for attention. + # k_lens = seq_lens // get_sequence_parallel_world_size() + # if k_lens is not None: + # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) + # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) + # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) + + x = xFuserLongContextAttention()( + None, + query=half(q), + key=half(k), + value=half(v), + window_size=self.window_size) + + # TODO: padding after attention. + # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) + + # output + x = x.flatten(2) + x = self.o(x) + return x diff --git a/vace/models/wan/modules/__init__.py b/vace/models/wan/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..307c3dd672ce42d271d0bff77a43c071aa32e271 --- /dev/null +++ b/vace/models/wan/modules/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from .model import VaceWanAttentionBlock, BaseWanAttentionBlock, VaceWanModel \ No newline at end of file diff --git a/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc b/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b3118f577a488e5f10f8805df2c06d90624bc8 Binary files /dev/null and b/vace/models/wan/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc b/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f7a8f2e7fb91c94bace996e2a6e29e9e47b8027 Binary files /dev/null and b/vace/models/wan/modules/__pycache__/mm_attention.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/model.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f151dd6b1295f3fae31d4473b0e055c8fc722ae2 Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..076e90fb3a6da535a99f30da7be29b566600f1ec Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_mm.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616e207ce99e05d1b4676975a9353f025620880e Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_ruan.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82183bc19c0ae07232079ad125283a0b82d4f158 Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_tr.cpython-311.pyc differ diff --git a/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc b/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..104ca0acabd89f205e09b768bf99c65df0cfa1d5 Binary files /dev/null and b/vace/models/wan/modules/__pycache__/model_wan.cpython-311.pyc differ diff --git a/vace/models/wan/modules/mm_attention.py b/vace/models/wan/modules/mm_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..84f2e0bcb055bc8bb26ed024c5335a6399e30a7d --- /dev/null +++ b/vace/models/wan/modules/mm_attention.py @@ -0,0 +1,747 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + +from wan.modules.attention import flash_attention +from wan.modules.model import WanSelfAttention, WanAttentionBlock + +__all__ = ['WanModel'] + +T5_CONTEXT_TOKEN_NUMBER = 512 +FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast(enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim, 2).to(torch.float64).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast(enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + seq_len, n, -1, 2)) + freqs_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +# class MMSelfAttention(nn.Module): + +# def __init__(self, +# dim, +# num_heads, +# window_size=(-1, -1), +# qk_norm=True, +# eps=1e-6): +# assert dim % num_heads == 0 +# super().__init__() +# self.dim = dim +# self.num_heads = num_heads +# self.head_dim = dim // num_heads +# self.window_size = window_size +# self.qk_norm = qk_norm +# self.eps = eps + +# # layers +# self.q = nn.Linear(dim, dim) +# self.k = nn.Linear(dim, dim) +# self.v = nn.Linear(dim, dim) +# self.o = nn.Linear(dim, dim) +# self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() +# self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + +# def forward(self, x, seq_lens, grid_sizes, grid_sizes_ref, freqs, len1): +# r""" +# Args: +# x(Tensor): Shape [B, L, num_heads, C / num_heads] +# seq_lens(Tensor): Shape [B] +# grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) +# freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] +# """ +# b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + +# # query, key, value function +# def qkv_fn(x): +# q = self.norm_q(self.q(x)).view(b, s, n, d) +# k = self.norm_k(self.k(x)).view(b, s, n, d) +# v = self.v(x).view(b, s, n, d) +# return q, k, v + +# q, k, v = qkv_fn(x) +# q[:, :len1] = rope_apply(q[:, :len1], grid_sizes, freqs) +# k[:, :len1] = rope_apply(k[:, :len1], grid_sizes, freqs) +# q[:, len1:] = rope_apply(q[:, len1:], grid_sizes_ref, freqs) +# k[:, len1:] = rope_apply(k[:, len1:], grid_sizes_ref, freqs) + +# x = flash_attention( +# q=q, +# k=k, +# v=v, +# k_lens=seq_lens, +# window_size=self.window_size) + +# # output +# x = x.flatten(2) +# x = self.o(x) +# return x + +class MMSelfAttention(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self.q_ref = nn.Linear(dim, dim) + self.k_ref = nn.Linear(dim, dim) + self.v_ref = nn.Linear(dim, dim) + self.norm_q_ref = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k_ref = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, x_ref, seq_lens, grid_sizes, grid_sizes_ref, freqs, freqs_ref, len1): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + b_ref, s_ref = x_ref.shape[:2] + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + def qkv_fn_ref(x_ref): + q_ref = self.norm_q_ref(self.q_ref(x_ref)).view(b_ref, s_ref, n, d) + k_ref = self.norm_k_ref(self.k_ref(x_ref)).view(b_ref, s_ref, n, d) + v_ref = self.v_ref(x_ref).view(b_ref, s_ref, n, d) + return q_ref, k_ref, v_ref + + q, k, v = qkv_fn(x) + q_ref, k_ref, v_ref = qkv_fn_ref(x_ref) + q = torch.cat([q_ref, q], dim=1) + k = torch.cat([k_ref, k], dim=1) + v = torch.cat([v_ref, v], dim=1) + + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + x = flash_attention( + q=q, + k=k, + v=v, + k_lens=seq_lens, + window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER + context_img = context[:, :image_context_length] + context = context[:, image_context_length:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img, k_lens=None) + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + 't2v_cross_attn': WanT2VCrossAttention, + 'i2v_cross_attn': WanI2VCrossAttention, +} + + +class MMAttentionBlock(nn.Module): + + def __init__(self, + cross_attn_type, + dim, + dim_ref, + ffn_dim, + ffn_dim_ref, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6): + super().__init__() + self.dim = dim + self.dim_ref = dim_ref + self.ffn_dim = ffn_dim + self.ffn_dim_ref = ffn_dim_ref + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = MMSelfAttention(dim, num_heads, window_size, qk_norm, + eps) + self.norm3 = WanLayerNorm( + dim, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, + num_heads, + (-1, -1), + qk_norm, + eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim, dim)) + + self.norm1_ref = WanLayerNorm(dim_ref, eps) + self.norm3_ref = WanLayerNorm( + dim_ref, eps, + elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.norm2_ref = WanLayerNorm(dim_ref, eps) + self.ffn_ref = nn.Sequential( + nn.Linear(dim_ref, ffn_dim_ref), nn.GELU(approximate='tanh'), + nn.Linear(ffn_dim_ref, dim_ref)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.modulation_ref = nn.Parameter(torch.randn(1, 6, dim_ref) / dim_ref**0.5) + + def forward( + self, + x, + x_ref, + e, + e_ref, + seq_lens, + seq_lens_ref, + grid_sizes, + grid_sizes_ref, + freqs, + freqs_ref, + context, + context_lens, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 and e_ref.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + e_ref = (self.modulation_ref + e_ref).chunk(6, dim=1) + assert e[0].dtype == torch.float32 and e_ref[0].dtype == torch.float32 + + len1, len2 = x.shape[1], x_ref.shape[1] + + # self-attention + y = self.self_attn( + self.norm1(x).float() * (1 + e[1]) + e[0], self.norm1_ref(x_ref).float() * (1 + e_ref[1]) + e_ref[0], + seq_lens, grid_sizes, grid_sizes_ref, freqs, freqs_ref, len1) + y_ref, y = y[:, :len2], y[:, len2:] + with amp.autocast(dtype=torch.float32): + x = x + y * e[2] + x_ref = x_ref + y_ref * e_ref[2] + + # cross-attention & ffn function + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) + with amp.autocast(dtype=torch.float32): + x = x + y * e[5] + + y_ref = self.ffn_ref(self.norm2_ref(x_ref).float() * (1 + e_ref[4]) + e_ref[3]) + with amp.autocast(dtype=torch.float32): + x_ref = x_ref + y_ref * e_ref[5] + return x, x_ref + + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + with amp.autocast(dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class MLPProj(torch.nn.Module): + + def __init__(self, in_dim, out_dim, flf_pos_emb=False): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim)) + if flf_pos_emb: # NOTE: we only use this for `flf2v` + self.emb_pos = nn.Parameter( + torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) + + def forward(self, image_embeds): + if hasattr(self, 'emb_pos'): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class MMModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = [ + 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' + ] + _no_split_modules = ['WanAttentionBlock'] + + @register_to_config + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + dim_ref=2048, + ffn_dim=8192, + ffn_dim_ref=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace' + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ['t2v', 'i2v', 'flf2v', 'vace'] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.dim_ref = dim_ref + self.ffn_dim = ffn_dim + self.ffn_dim_ref = ffn_dim_ref + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), + nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + self.time_embedding_ref = nn.Sequential( + nn.Linear(freq_dim, dim_ref), nn.SiLU(), nn.Linear(dim_ref, dim_ref)) + self.time_projection_ref = nn.Sequential(nn.SiLU(), nn.Linear(dim_ref, dim_ref * 6)) + + # blocks + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + + self.blocks = nn.ModuleList([ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([ + rope_params(1024, d - 4 * (d // 6)), + rope_params(1024, 2 * (d // 6)), + rope_params(1024, 2 * (d // 6)) + ], + dim=1) + self.freqs_ref = torch.cat([ + rope_params(1024, d - 4 * (d // 6), 5000), + rope_params(1024, 2 * (d // 6), 5000), + rope_params(1024, 2 * (d // 6), 5000) + ], + dim=1) + + if model_type == 'i2v' or model_type == 'flf2v': + self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode or first-last-frame-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == 'i2v' or self.model_type == 'flf2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/vace/models/wan/modules/model.py b/vace/models/wan/modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..31efed7b9fa90a5061263f30f74dd3da79f2a127 --- /dev/null +++ b/vace/models/wan/modules/model.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d +import torch.utils.checkpoint + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output + + +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0, **kwargs): + x = super().forward(x, **kwargs) + if self.block_id is not None: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceWanModel(WanModel): + @register_to_config + def __init__(self, + vace_layers=None, + vace_in_dim=None, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + model_type = "t2v" # TODO: Hard code for both preview and official versions. + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + def forward_vace( + self, + x, + vace_context, + seq_len, + kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + c = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, + **new_kwargs, + ) + hints = torch.unbind(c)[:-1] + return hints + + def forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + y=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + ): + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # text embedding + context_lens = None + context = self.text_embedding(context) + + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + ) + + hints = self.forward_vace( + x, vace_context, seq_len, kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + for block in self.blocks: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, + **kwargs, + ) + + x = self.head(x, e) + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] diff --git a/vace/models/wan/modules/model_mm.py b/vace/models/wan/modules/model_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..bc531eafbb31456ef909ad1b4db2d4b2dc5477de --- /dev/null +++ b/vace/models/wan/modules/model_mm.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d +from .mm_attention import MMModel, MMAttentionBlock +import torch.utils.checkpoint + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output + + +class VaceWanAttentionBlock(MMAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + dim_ref, + ffn_dim, + ffn_dim_ref, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, dim_ref, ffn_dim, ffn_dim_ref, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.before_proj_ref = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj_ref.weight) + nn.init.zeros_(self.before_proj_ref.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, c_ref, x, x_ref, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c_ref = self.before_proj_ref(c_ref) + x_ref + c, c_ref = super().forward(c, c_ref, **kwargs) + c_skip = self.after_proj(torch.cat([c_ref, c], dim=1)) + return c, c_ref, c_skip + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0, **kwargs): + x = super().forward(x, **kwargs) + if self.block_id is not None: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceMMModel(MMModel): + @register_to_config + def __init__(self, + vace_layers=None, + vace_in_dim=None, + ref_in_dim=None, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + dim_ref=2048, + ffn_dim=8192, + ffn_dim_ref=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + model_type = "t2v" # TODO: Hard code for both preview and official versions. + super().__init__(model_type, patch_size, text_len, in_dim, dim, dim_ref, ffn_dim, ffn_dim_ref, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + self.ref_in_dim = ref_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.dim_ref, self.ffn_dim, self.ffn_dim_ref, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.ref_patch_embedding = nn.Conv3d( + self.ref_in_dim, self.dim_ref, kernel_size=self.patch_size, stride=self.patch_size + ) + + def forward_vace( + self, + x, + x_ref, + vace_context, + ref_context, + seq_len, + seq_len_ref, + kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + # embeddings + c = self.vace_patch_embedding(vace_context) + c = c.flatten(2).transpose(1, 2) + # c = torch.cat([ + # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + # dim=1) for u in c + # ]) + + c_ref = self.ref_patch_embedding(ref_context) + c_ref = c_ref.flatten(2).transpose(1, 2) + # c_ref = torch.cat([ + # torch.cat([u, u.new_zeros(1, seq_len_ref - u.size(1), u.size(2))], + # dim=1) for u in c_ref + # ]) + + new_kwargs = dict(x=x, x_ref=x_ref) + new_kwargs.update(kwargs) + hints = [] + for block in self.vace_blocks: + c, c_ref, c_skip = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, + c_ref, + **new_kwargs, + ) + if c_skip is not None: + hints.append(c_skip) + return hints + + def forward( + self, + x, + t, + vace_context, + ref_context, + context, + seq_len, + seq_len_ref, + vace_context_scale=1.0, + clip_fea=None, + y=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + ): + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + self.freqs_ref = self.freqs_ref.to(device) + + # embeddings + + grid_sizes = torch.tensor([x.shape[2]/self.patch_size[0], x.shape[3]/self.patch_size[1], x.shape[4]/self.patch_size[2]]).long().unsqueeze(0).repeat(x.shape[0], 1) + seq_lens = torch.tensor( + (x.shape[2]//self.patch_size[0]) + * (x.shape[3]//self.patch_size[1]) + * (x.shape[4]//self.patch_size[2]) + ).repeat(x.shape[0]).long() + x_ref = self.patch_embedding(x[:, :, :1, :, :]) + x_ref = x_ref.flatten(2).transpose(1, 2) + x_vid = self.patch_embedding(x[:, :, 1:, :, :]) + x_vid = x_vid.flatten(2).transpose(1, 2) + # print(x.dtype) + grid_sizes_ref = torch.tensor([ref_context.shape[2]/self.patch_size[0], ref_context.shape[3]/self.patch_size[1], ref_context.shape[4]/self.patch_size[2]]).long().unsqueeze(0).repeat(grid_sizes.shape[0], 1) + + seq_lens_ref = torch.tensor( + (ref_context.shape[2]//self.patch_size[0]) + * (ref_context.shape[3]//self.patch_size[1]) + * (ref_context.shape[4]//self.patch_size[2]) + ).repeat(x.shape[0]).long() + assert seq_lens.max() <= seq_len + # x = torch.cat([ + # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + # dim=1) for u in x + # ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + e_ref = self.time_embedding_ref( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0_ref = self.time_projection_ref(e_ref).unflatten(1, (6, self.dim)) + + # text embedding + context_lens = None + context = self.text_embedding(context) + + kwargs = dict( + e=e0, + e_ref=e0_ref, + seq_lens=seq_lens, + seq_lens_ref=seq_lens_ref, + grid_sizes=grid_sizes, + grid_sizes_ref=grid_sizes_ref, + freqs=self.freqs, + freqs_ref=self.freqs_ref, + context=context, + context_lens=context_lens, + ) + # ========== 支路 VACE ========== + hints = self.forward_vace( + x_vid, x_ref, vace_context, ref_context, seq_len, seq_len_ref, kwargs, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + kwargs.pop("e_ref") + kwargs.pop("seq_lens_ref") + kwargs.pop("grid_sizes_ref") + kwargs.pop("freqs_ref") + + # ========== 主路 BLOCKS ========== + x = torch.cat([x_ref, x_vid], dim=1) + for block in self.blocks: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, + **kwargs, + ) + + x = self.head(x, e) + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] diff --git a/vace/models/wan/modules/model_tr.py b/vace/models/wan/modules/model_tr.py new file mode 100644 index 0000000000000000000000000000000000000000..abbfc02a84c47e37fa339970db8e291ec10ec692 --- /dev/null +++ b/vace/models/wan/modules/model_tr.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from wan.modules.model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d + + +def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + + return custom_forward + + +def gradient_checkpoint_forward( + model, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + *args, + **kwargs, +): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + model_output = torch.utils.checkpoint.checkpoint( + create_custom_forward(model), + *args, + **kwargs, + use_reentrant=False, + ) + else: + model_output = model(*args, **kwargs) + return model_output + + +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0, + ): + super().__init__( + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + ) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None, + ): + super().__init__( + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size, + qk_norm, + cross_attn_norm, + eps, + ) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0, **kwargs): + x = super().forward(x, **kwargs) + if self.block_id is not None: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceWanModel(WanModel): + @register_to_config + def __init__( + self, + vace_layers=None, + vace_in_dim=None, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + activation_checkpointing=False, + eps=1e-6, + ): + super().__init__( + model_type, + patch_size, + text_len, + in_dim, + dim, + ffn_dim, + freq_dim, + text_dim, + out_dim, + num_heads, + num_layers, + window_size, + qk_norm, + cross_attn_norm, + eps, + ) + + self.vace_layers = ( + [i for i in range(0, self.num_layers, 2)] + if vace_layers is None + else vace_layers + ) + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList( + [ + BaseWanAttentionBlock( + "t2v_cross_attn", + self.dim, + self.ffn_dim, + self.num_heads, + self.window_size, + self.qk_norm, + self.cross_attn_norm, + self.eps, + block_id=( + self.vace_layers_mapping[i] if i in self.vace_layers else None + ), + ) + for i in range(self.num_layers) + ] + ) + + # vace blocks + self.vace_blocks = nn.ModuleList( + [ + VaceWanAttentionBlock( + "t2v_cross_attn", + self.dim, + self.ffn_dim, + self.num_heads, + self.window_size, + self.qk_norm, + self.cross_attn_norm, + self.eps, + block_id=i, + ) + for i in self.vace_layers + ] + ) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, + self.dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + def forward_vace(self, x, vace_context, seq_len, kwargs): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat( + [ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in c + ] + ) + + # arguments + new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + # for block in self.vace_blocks: + # c = gradient_checkpoint_forward( + # block,s + # True, + # False, + # c, + # **new_kwargs, + # ) + + hints = torch.unbind(c)[:-1] + return hints + + def forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + use_gradient_checkpointing=False, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x] + ) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + # print(seq_lens) + # print(seq_len) + assert seq_lens.max() <= seq_len + x = torch.cat( + [ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) + for u in x + ] + ) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # context = self.text_embedding(context) + + # if clip_fea is not None: + # context_clip = self.img_emb(clip_fea) # bs x 257 x dim + # context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + ) + + hints = self.forward_vace(x, vace_context, seq_len, kwargs) + kwargs["hints"] = hints + kwargs["context_scale"] = vace_context_scale + + # for block in self.blocks: + + # if self.activation_checkpointing: + # ### Some bug here with deepspeed zero + # # print("use activation checkpointing") + # # x = x.requires_grad_() + # x = torch.utils.checkpoint.checkpoint( + # lambda inp: block(inp, **kwargs), x, use_reentrant=False + # ) + # else: + # x = block(x, **kwargs) + + for block in self.blocks: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + False, + x, + **kwargs, + ) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + # return [u.float() for u in x] + # print(x.shape) + return x diff --git a/vace/models/wan/wan_vace.py b/vace/models/wan/wan_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..d388c5073c28d13fbea987a2c3cdad8ae703dfe8 --- /dev/null +++ b/vace/models/wan/wan_vace.py @@ -0,0 +1,719 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +import gc +import math +import time +import random +import types +import logging +import traceback +from contextlib import contextmanager +from functools import partial + +from PIL import Image +import torchvision.transforms.functional as TF +import torch +import torch.nn.functional as F +import torch.cuda.amp as amp +import torch.distributed as dist +import torch.multiprocessing as mp +from tqdm import tqdm + +from wan.text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) +from .modules.model import VaceWanModel +from ..utils.preprocessor import VaceVideoProcessor + + +class WanVace(WanT2V): + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the Wan text-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating VaceWanModel from {checkpoint_dir}") + self.model = VaceWanModel.from_pretrained(checkpoint_dir) + self.model.eval().requires_grad_(False) + + if use_usp: + from xfuser.core.distributed import \ + get_sequence_parallel_world_size + + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in self.model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in self.model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + self.model.forward = types.MethodType(usp_dit_forward, self.model) + self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model) + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + if dit_fsdp: + self.model = shard_fn(self.model) + else: + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=480 * 832, + max_area=480 * 832, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=None): + vae_stride = self.vae_stride if vae_stride is None else vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 720*1280: + self.vid_proc.set_seq_len(75600) + elif area == 480*832: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_size = (image_size[1], image_size[0]) + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return vae.decode(trimed_zs) + + + + def generate(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + size=(1280, 720), + frame_num=81, + context_scale=1.0, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + # F = frame_num + # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + # size[1] // self.vae_stride[1], + # size[0] // self.vae_stride[2]) + # + # seq_len = math.ceil((target_shape[2] * target_shape[3]) / + # (self.patch_size[1] * self.patch_size[2]) * + # target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + # vace context encode + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.decode_latent(x0, input_ref_images) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + +class WanVaceMP(WanVace): + def __init__( + self, + config, + checkpoint_dir, + use_usp=False, + ulysses_size=None, + ring_size=None + ): + self.config = config + self.checkpoint_dir = checkpoint_dir + self.use_usp = use_usp + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '1' + self.in_q_list = None + self.out_q = None + self.inference_pids = None + self.ulysses_size = ulysses_size + self.ring_size = ring_size + self.dynamic_load() + + self.device = 'cpu' if torch.cuda.is_available() else 'cpu' + self.vid_proc = VaceVideoProcessor( + downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]), + min_area=720 * 1280, + max_area=720 * 1280, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=75600, + keep_last=True) + + + def dynamic_load(self): + if hasattr(self, 'inference_pids') and self.inference_pids is not None: + return + gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count() + pmi_rank = int(os.environ['RANK']) + pmi_world_size = int(os.environ['WORLD_SIZE']) + in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)] + out_q = torch.multiprocessing.Manager().Queue() + initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)] + context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False) + all_initialized = False + while not all_initialized: + all_initialized = all(event.is_set() for event in initialized_events) + if not all_initialized: + time.sleep(0.1) + print('Inference model is initialized', flush=True) + self.in_q_list = in_q_list + self.out_q = out_q + self.inference_pids = context.pids() + self.initialized_events = initialized_events + + def transfer_data_to_cuda(self, data, device): + if data is None: + return None + else: + if isinstance(data, torch.Tensor): + data = data.to(device) + elif isinstance(data, list): + data = [self.transfer_data_to_cuda(subdata, device) for subdata in data] + elif isinstance(data, dict): + data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()} + return data + + def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env): + try: + world_size = pmi_world_size * gpu_infer + rank = pmi_rank * gpu_infer + gpu + print("world_size", world_size, "rank", rank, flush=True) + + torch.cuda.set_device(gpu) + dist.init_process_group( + backend='nccl', + init_method='env://', + rank=rank, + world_size=world_size + ) + + from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) + init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=self.ring_size or 1, + ulysses_degree=self.ulysses_size or 1 + ) + + num_train_timesteps = self.config.num_train_timesteps + param_dtype = self.config.param_dtype + shard_fn = partial(shard_model, device_id=gpu) + text_encoder = T5EncoderModel( + text_len=self.config.text_len, + dtype=self.config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint), + tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer), + shard_fn=shard_fn if True else None) + text_encoder.model.to(gpu) + vae_stride = self.config.vae_stride + patch_size = self.config.patch_size + vae = WanVAE( + vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint), + device=gpu) + logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}") + model = VaceWanModel.from_pretrained(self.checkpoint_dir) + model.eval().requires_grad_(False) + + if self.use_usp: + from xfuser.core.distributed import get_sequence_parallel_world_size + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + model.forward = types.MethodType(usp_dit_forward, model) + model.forward_vace = types.MethodType(usp_dit_forward_vace, model) + sp_size = get_sequence_parallel_world_size() + else: + sp_size = 1 + + dist.barrier() + model = shard_fn(model) + sample_neg_prompt = self.config.sample_neg_prompt + + torch.cuda.empty_cache() + event = initialized_events[gpu] + in_q = in_q_list[gpu] + event.set() + + while True: + item = in_q.get() + input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \ + shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item + input_frames = self.transfer_data_to_cuda(input_frames, gpu) + input_masks = self.transfer_data_to_cuda(input_masks, gpu) + input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu) + + if n_prompt == "": + n_prompt = sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=gpu) + seed_g.manual_seed(seed) + + context = text_encoder([input_prompt], gpu) + context_null = text_encoder([n_prompt], gpu) + + # vace context encode + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae) + m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=gpu, + generator=seed_g) + ] + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (patch_size[1] * patch_size[2]) * + target_shape[1] / sp_size) * sp_size + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=gpu, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=gpu, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + model.to(gpu) + noise_pred_cond = model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[ + 0] + noise_pred_uncond = model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, + **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + torch.cuda.empty_cache() + x0 = latents + if rank == 0: + videos = self.decode_latent(x0, input_ref_images, vae=vae) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + if rank == 0: + out_q.put(videos[0].cpu()) + + except Exception as e: + trace_info = traceback.format_exc() + print(trace_info, flush=True) + print(e, flush=True) + + + + def generate(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + size=(1280, 720), + frame_num=81, + context_scale=1.0, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + + input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, + shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model) + for in_q in self.in_q_list: + in_q.put(input_data) + value_output = self.out_q.get() + + return value_output diff --git a/vae.py b/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c42f6aa659099fdf28621edd64ea3b86ace4b0 --- /dev/null +++ b/vae.py @@ -0,0 +1,707 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +# import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'WanVAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +# class CausalConv3dNew(nn.Conv3d): +# """ +# Causal 3d convolusion. +# """ + +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self._padding = (self.padding[2], self.padding[2], self.padding[1], +# self.padding[1], 2 * self.padding[0], 0) +# self.padding = (0, 0, 0) + +# def forward(self, x, cache_x=None): +# padding = list(self._padding) +# if cache_x is not None and self._padding[4] > 0: +# cache_x = cache_x.to(x.device) +# x = torch.cat([cache_x, x], dim=2) +# padding[4] -= cache_x.shape[2] +# # x = F.pad(x, padding) + +# # return super().forward(x) + +# return F.conv3d(x,weight=self.weight, +# bias=self.bias, +# stride=self.stride, +# dilation=self.dilation, +# padding=padding) + + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class WanVAE(nn.Module): + + def __init__(self, + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, + device="cuda"): + super().__init__() + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with torch.autocast(device_type='cuda',dtype=self.dtype): + return self.model.encode(videos, self.scale) + + + def decode(self, zs): + with torch.autocast(device_type='cuda',dtype=self.dtype): + return self.model.decode(zs,self.scale).float().clamp_(-1, 1) + + + # def encode(self, videos): + # """ + # videos: A list of videos each with shape [C, T, H, W]. + # """ + # with torch.autocast(device_type='cuda',dtype=self.dtype): + # return [ + # self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) + # for u in videos + # ] + + # def decode(self, zs): + # with torch.autocast(device_type='cuda',dtype=self.dtype): + # return [ + # self.model.decode(u.unsqueeze(0), + # self.scale).float().clamp_(-1, 1).squeeze(0) + # for u in zs + # ] +