# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
get_1d_sine_pe,
get_next_point,
sample_box_points,
select_closest_cond_frames,
)
from sam2.utils.misc import concat_points
from training.utils.data_utils import BatchedVideoDatapoint
class SAM2Train(SAM2Base):
def __init__(
self,
image_encoder,
memory_attention=None,
memory_encoder=None,
prob_to_use_pt_input_for_train=0.0,
prob_to_use_pt_input_for_eval=0.0,
prob_to_use_box_input_for_train=0.0,
prob_to_use_box_input_for_eval=0.0,
        # if it is greater than 1, we perform interactive point sampling on the 1st frame and other randomly selected frames
num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
rand_frames_to_correct_for_train=False,
rand_frames_to_correct_for_eval=False,
# how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
# - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
# - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
# note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
# these are initial conditioning frames because as we track the video, more conditioning frames might be added
# when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
        rand_init_cond_frames_for_train=True,  # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
rand_init_cond_frames_for_eval=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
add_all_frames_to_correct_as_cond=False,
# how many additional correction points to sample (on each frame selected to be corrected)
        # note that each initial conditioning frame receives an initial input click (in addition to any correction clicks)
num_correction_pt_per_frame=7,
# method for point sampling during evaluation
# "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
# default to "center" to be consistent with evaluation in the SAM paper
pt_sampling_for_eval="center",
# During training, we optionally allow sampling the correction points from GT regions
# instead of the prediction error regions with a small probability. This might allow the
# model to overfit less to the error regions in training datasets
prob_to_sample_from_gt_for_train=0.0,
use_act_ckpt_iterative_pt_sampling=False,
# whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
# of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
forward_backbone_per_frame_for_eval=False,
freeze_image_encoder=False,
**kwargs,
):
super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
# Point sampler and conditioning frames
self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            logging.info(
                "Training with points (sampled from masks) as inputs with "
                f"p={prob_to_use_pt_input_for_train} (train) / "
                f"{prob_to_use_pt_input_for_eval} (eval)"
            )
assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval
self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
# Initial multi-conditioning frames
self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.num_correction_pt_per_frame = num_correction_pt_per_frame
self.pt_sampling_for_eval = pt_sampling_for_eval
self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
# A random number generator with a fixed initial seed across GPUs
self.rng = np.random.default_rng(seed=42)
if freeze_image_encoder:
for p in self.image_encoder.parameters():
p.requires_grad = False
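
    # A minimal construction sketch (illustrative only; every `my_*` module and
    # the numbers below are assumptions, not the actual training config -- in
    # practice these components are built from the training configs):
    #
    #   model = SAM2Train(
    #       image_encoder=my_image_encoder,
    #       memory_attention=my_memory_attention,
    #       memory_encoder=my_memory_encoder,
    #       prob_to_use_pt_input_for_train=0.5,  # half the batches use clicks
    #       num_init_cond_frames_for_train=2,
    #       num_frames_to_correct_for_train=2,
    #       freeze_image_encoder=True,  # train heads/memory modules only
    #   )
    #   outputs = model(batch)  # batch: BatchedVideoDatapoint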
    def forward(self, input: BatchedVideoDatapoint):
        """Encode images, prepare prompt inputs, and run tracking over the video."""
if self.training or not self.forward_backbone_per_frame_for_eval:
# precompute image features on all frames before tracking
backbone_out = self.forward_image(input.flat_img_batch)
else:
# defer image feature computation on a frame until it's being tracked
backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
backbone_out = self.prepare_prompt_inputs(backbone_out, input)
previous_stages_out = self.forward_tracking(backbone_out, input)
return previous_stages_out
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
"""Compute the image backbone features on the fly for the given img_ids."""
# Only forward backbone on unique image ids to avoid repetitive computation
# (if `img_ids` has only one element, it's already unique so we skip this step).
if img_ids.numel() > 1:
unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
else:
unique_img_ids, inv_ids = img_ids, None
# Compute the image features on those unique image ids
image = img_batch[unique_img_ids]
backbone_out = self.forward_image(image)
(
_,
vision_feats,
vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features(backbone_out)
# Inverse-map image features for `unique_img_ids` to the final image features
# for the original input `img_ids`.
if inv_ids is not None:
image = image[inv_ids]
vision_feats = [x[:, inv_ids] for x in vision_feats]
vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
return image, vision_feats, vision_pos_embeds, feat_sizes
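
    # Illustration of the unique-then-inverse-map trick above (concrete but
    # hypothetical values): for img_ids = tensor([2, 0, 2]),
    # torch.unique(img_ids, return_inverse=True) yields
    # unique_img_ids = tensor([0, 2]) and inv_ids = tensor([1, 0, 1]), so the
    # backbone runs on 2 images instead of 3, and indexing the resulting
    # features with inv_ids restores the original per-object ordering.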
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
"""
Prepare input mask, point or box prompts. Optionally, we allow tracking from
a custom `start_frame_idx` to the end of the video (for evaluation purposes).
"""
# Load the ground-truth masks on all frames (so that we can later
# sample correction points from them)
gt_masks_per_frame = {
stage_id: masks.unsqueeze(1) # [B, 1, H_im, W_im]
for stage_id, masks in enumerate(input.masks)
}
backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
num_frames = input.num_frames
backbone_out["num_frames"] = num_frames
# Randomly decide whether to use point inputs or mask inputs
if self.training:
prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
prob_to_use_box_input = self.prob_to_use_box_input_for_train
num_frames_to_correct = self.num_frames_to_correct_for_train
rand_frames_to_correct = self.rand_frames_to_correct_for_train
num_init_cond_frames = self.num_init_cond_frames_for_train
rand_init_cond_frames = self.rand_init_cond_frames_for_train
else:
prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
prob_to_use_box_input = self.prob_to_use_box_input_for_eval
num_frames_to_correct = self.num_frames_to_correct_for_eval
rand_frames_to_correct = self.rand_frames_to_correct_for_eval
num_init_cond_frames = self.num_init_cond_frames_for_eval
rand_init_cond_frames = self.rand_init_cond_frames_for_eval
if num_frames == 1:
# here we handle a special case for mixing video + SAM on image training,
# where we force using point input for the SAM task on static images
prob_to_use_pt_input = 1.0
num_frames_to_correct = 1
num_init_cond_frames = 1
assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns a value in the range 0.0 <= X < 1.0)
use_pt_input = self.rng.random() < prob_to_use_pt_input
if rand_init_cond_frames and num_init_cond_frames > 1:
# randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
num_init_cond_frames = self.rng.integers(
1, num_init_cond_frames, endpoint=True
)
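        # (`endpoint=True` makes `rng.integers` inclusive of the upper bound;
        # e.g. rng.integers(1, 3, endpoint=True) draws uniformly from {1, 2, 3})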
if (
use_pt_input
and rand_frames_to_correct
and num_frames_to_correct > num_init_cond_frames
):
# randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
# correction clicks (only for the case of point input)
num_frames_to_correct = self.rng.integers(
num_init_cond_frames, num_frames_to_correct, endpoint=True
)
backbone_out["use_pt_input"] = use_pt_input
# Sample initial conditioning frames
if num_init_cond_frames == 1:
init_cond_frames = [start_frame_idx] # starting frame
else:
# starting frame + randomly selected remaining frames (without replacement)
init_cond_frames = [start_frame_idx] + self.rng.choice(
range(start_frame_idx + 1, num_frames),
num_init_cond_frames - 1,
replace=False,
).tolist()
backbone_out["init_cond_frames"] = init_cond_frames
backbone_out["frames_not_in_init_cond"] = [
t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
]
# Prepare mask or point inputs on initial conditioning frames
backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: <input_masks>}
backbone_out["point_inputs_per_frame"] = {} # {frame_idx: <input_points>}
for t in init_cond_frames:
if not use_pt_input:
backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
else:
                # During training, P(box) = prob_to_use_pt_input * prob_to_use_box_input
use_box_input = self.rng.random() < prob_to_use_box_input
if use_box_input:
points, labels = sample_box_points(
gt_masks_per_frame[t],
)
else:
# (here we only sample **one initial point** on initial conditioning frames from the
# ground-truth mask; we may sample more correction points on the fly)
points, labels = get_next_point(
gt_masks=gt_masks_per_frame[t],
pred_masks=None,
method=(
"uniform" if self.training else self.pt_sampling_for_eval
),
)
point_inputs = {"point_coords": points, "point_labels": labels}
backbone_out["point_inputs_per_frame"][t] = point_inputs
# Sample frames where we will add correction clicks on the fly
# based on the error between prediction and ground-truth masks
if not use_pt_input:
# no correction points will be sampled when using mask inputs
frames_to_add_correction_pt = []
elif num_frames_to_correct == num_init_cond_frames:
frames_to_add_correction_pt = init_cond_frames
else:
assert num_frames_to_correct > num_init_cond_frames
            # initial cond frames + randomly selected remaining frames (without replacement)
extra_num = num_frames_to_correct - num_init_cond_frames
frames_to_add_correction_pt = (
init_cond_frames
+ self.rng.choice(
backbone_out["frames_not_in_init_cond"], extra_num, replace=False
).tolist()
)
backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt
return backbone_out
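
    # Illustrative example of the prompt bookkeeping this method produces
    # (hypothetical draw with num_frames=8, point input, 2 init cond frames,
    # and 3 frames to correct):
    #   backbone_out["init_cond_frames"]            -> [0, 5]
    #   backbone_out["frames_not_in_init_cond"]     -> [1, 2, 3, 4, 6, 7]
    #   backbone_out["point_inputs_per_frame"]      -> {0: {...}, 5: {...}}
    #   backbone_out["mask_inputs_per_frame"]       -> {}  (point input used)
    #   backbone_out["frames_to_add_correction_pt"] -> [0, 5, 3]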
def forward_tracking(
self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
):
"""Forward video tracking on each frame (and sample correction clicks)."""
img_feats_already_computed = backbone_out["backbone_fpn"] is not None
if img_feats_already_computed:
# Prepare the backbone features
# - vision_feats and vision_pos_embeds are in (HW)BC format
(
_,
vision_feats,
vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features(backbone_out)
# Starting the stage loop
num_frames = backbone_out["num_frames"]
init_cond_frames = backbone_out["init_cond_frames"]
frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then condition on them to track the remaining frames
processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
output_dict = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
for stage_id in processing_order:
# Get the image features for the current frames
img_ids = input.flat_obj_to_img_idx[stage_id]
if img_feats_already_computed:
# Retrieve image features according to img_ids (if they are already computed).
current_vision_feats = [x[:, img_ids] for x in vision_feats]
current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
else:
# Otherwise, compute the image features on the fly for the given img_ids
# (this might be used for evaluation on long videos to avoid backbone OOM).
(
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._prepare_backbone_features_per_frame(
input.flat_img_batch, img_ids
)
# Get output masks based on this frame's prompts and previous memory
current_out = self.track_step(
frame_idx=stage_id,
is_init_cond_frame=stage_id in init_cond_frames,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
frames_to_add_correction_pt=frames_to_add_correction_pt,
output_dict=output_dict,
num_frames=num_frames,
)
# Append the output, depending on whether it's a conditioning frame
add_output_as_cond_frame = stage_id in init_cond_frames or (
self.add_all_frames_to_correct_as_cond
and stage_id in frames_to_add_correction_pt
)
if add_output_as_cond_frame:
output_dict["cond_frame_outputs"][stage_id] = current_out
else:
output_dict["non_cond_frame_outputs"][stage_id] = current_out
if return_dict:
return output_dict
        # turn `output_dict` into a per-frame list for the loss function
all_frame_outputs = {}
all_frame_outputs.update(output_dict["cond_frame_outputs"])
all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
# Make DDP happy with activation checkpointing by removing unused keys
all_frame_outputs = [
{k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
]
return all_frame_outputs
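
    # Each element of the returned per-frame list is a dict carrying (among
    # other keys) "pred_masks", "pred_masks_high_res", and the "multistep_*"
    # entries filled in by `track_step` below; "obj_ptr" is stripped above
    # since it is unused by the loss and would otherwise trip DDP's
    # unused-parameter detection under activation checkpointing.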
def track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse=False, # tracking in reverse time order (for demo usage)
run_mem_encoder=True, # Whether to run the memory encoder on the predicted masks.
prev_sam_mask_logits=None, # The previously predicted SAM mask logits.
frames_to_add_correction_pt=None,
gt_masks=None,
    ):
        """Track one frame: run the SAM heads, optionally sample correction clicks, and encode memory."""
if frames_to_add_correction_pt is None:
frames_to_add_correction_pt = []
current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
)
(
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = sam_outputs
current_out["multistep_pred_masks"] = low_res_masks
current_out["multistep_pred_masks_high_res"] = high_res_masks
current_out["multistep_pred_multimasks"] = [low_res_multimasks]
current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
current_out["multistep_pred_ious"] = [ious]
current_out["multistep_point_inputs"] = [point_inputs]
current_out["multistep_object_score_logits"] = [object_score_logits]
# Optionally, sample correction points iteratively to correct the mask
if frame_idx in frames_to_add_correction_pt:
point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
is_init_cond_frame,
point_inputs,
gt_masks,
high_res_features,
pix_feat,
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
object_score_logits,
current_out,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = final_sam_outputs
        # Use the final prediction (after all correction steps) for output and eval
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
)
return current_out
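
    # Shape sketch for the multistep outputs (assuming low-res masks of shape
    # [B, 1, H, W] per step): `_iter_correct_pt_sampling` below concatenates
    # along dim=1, so after all correction steps
    # current_out["multistep_pred_masks"] has shape
    # [B, 1 + num_correction_pt_per_frame, H, W].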
def _iter_correct_pt_sampling(
self,
is_init_cond_frame,
point_inputs,
gt_masks,
high_res_features,
pix_feat_with_mem,
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
object_score_logits,
current_out,
    ):
        """Iteratively sample correction points and re-run the SAM heads on them."""
assert gt_masks is not None
all_pred_masks = [low_res_masks]
all_pred_high_res_masks = [high_res_masks]
all_pred_multimasks = [low_res_multimasks]
all_pred_high_res_multimasks = [high_res_multimasks]
all_pred_ious = [ious]
all_point_inputs = [point_inputs]
all_object_score_logits = [object_score_logits]
for _ in range(self.num_correction_pt_per_frame):
# sample a new point from the error between prediction and ground-truth
# (with a small probability, directly sample from GT masks instead of errors)
if self.training and self.prob_to_sample_from_gt_for_train > 0:
sample_from_gt = (
self.rng.random() < self.prob_to_sample_from_gt_for_train
)
else:
sample_from_gt = False
# if `pred_for_new_pt` is None, only GT masks will be used for point sampling
pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
new_points, new_labels = get_next_point(
gt_masks=gt_masks,
pred_masks=pred_for_new_pt,
method="uniform" if self.training else self.pt_sampling_for_eval,
)
point_inputs = concat_points(point_inputs, new_points, new_labels)
# Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
# For tracking, this means that when the user adds a correction click, we also feed
# the tracking output mask logits along with the click as input to the SAM decoder.
mask_inputs = low_res_masks
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
sam_outputs = torch.utils.checkpoint.checkpoint(
self._forward_sam_heads,
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
use_reentrant=False,
)
else:
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat_with_mem,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
(
low_res_multimasks,
high_res_multimasks,
ious,
low_res_masks,
high_res_masks,
_,
object_score_logits,
) = sam_outputs
all_pred_masks.append(low_res_masks)
all_pred_high_res_masks.append(high_res_masks)
all_pred_multimasks.append(low_res_multimasks)
all_pred_high_res_multimasks.append(high_res_multimasks)
all_pred_ious.append(ious)
all_point_inputs.append(point_inputs)
all_object_score_logits.append(object_score_logits)
        # Concatenate the masks along the channel dim (to compute losses on all
        # of the multi-step interactive predictions)
current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
current_out["multistep_pred_masks_high_res"] = torch.cat(
all_pred_high_res_masks, dim=1
)
current_out["multistep_pred_multimasks"] = all_pred_multimasks
current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
current_out["multistep_pred_ious"] = all_pred_ious
current_out["multistep_point_inputs"] = all_point_inputs
current_out["multistep_object_score_logits"] = all_object_score_logits
return point_inputs, sam_outputs
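

if __name__ == "__main__":
    # Standalone sketch of the frame-sampling logic in `prepare_prompt_inputs`
    # (illustrative only; runs without any model weights or data). It mirrors
    # how initial conditioning frames and the remaining tracked frames are
    # drawn for a hypothetical 8-frame video with up to 2 init cond frames.
    rng = np.random.default_rng(seed=42)
    num_frames, start_frame_idx = 8, 0
    num_init_cond_frames = int(rng.integers(1, 2, endpoint=True))  # random 1~2
    if num_init_cond_frames == 1:
        init_cond_frames = [start_frame_idx]
    else:
        # starting frame + randomly selected remaining frames (w/o replacement)
        init_cond_frames = [start_frame_idx] + rng.choice(
            range(start_frame_idx + 1, num_frames),
            num_init_cond_frames - 1,
            replace=False,
        ).tolist()
    frames_not_in_init_cond = [
        t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
    ]
    print("init cond frames:    ", init_cond_frames)
    print("frames tracked after:", frames_not_in_init_cond)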