import math

import torch
|
| |
|
| | class SpriteHeadStabilizeX:
|
| | """
|
| | Stabilize sprite animation wiggle (X only) using a Y-band (e.g. head region).
|
| |
|
| | Align frames 1..N to frame 0 by estimating horizontal shift from alpha visibility
|
| | inside the selected Y-range.
|
| |
|
| | Methods:
|
| | - bbox_center: leftmost/rightmost visible pixel columns -> center
|
| | - alpha_com: alpha-weighted center-of-mass (recommended)
|
| | - profile_corr: phase correlation on horizontal alpha profile (very robust)
|
| | - hybrid: profile_corr with a sanity check fallback to alpha_com
|
| |
|
| | Inputs support:
|
| | - True RGBA IMAGE tensor (C>=4) => alpha taken from channel 4
|
| | - Or IMAGE (RGB) + MASK (ComfyUI LoadImage mask) => alpha derived from mask
|
| | """
|
| |
|
| | @classmethod
|
| | def INPUT_TYPES(cls):
|
| | return {
|
| | "required": {
|
| | "images": ("IMAGE", {}),
|
| |
|
| |
|
| | "y_min": ("INT", {"default": 210, "min": -99999, "max": 99999, "step": 1}),
|
| | "y_max": ("INT", {"default": 332, "min": -99999, "max": 99999, "step": 1}),
|
| |
|
| |
|
| | "alpha_threshold_8bit": ("INT", {"default": 5, "min": 0, "max": 255, "step": 1}),
|
| |
|
| | "method": (["bbox_center", "alpha_com", "profile_corr", "hybrid"], {"default": "alpha_com"}),
|
| |
|
| |
|
| |
|
| | "mask_is_inverted": ("BOOLEAN", {"default": True}),
|
| |
|
| |
|
| | "max_abs_shift": ("INT", {"default": 0, "min": 0, "max": 99999, "step": 1}),
|
| | "temporal_median": ("INT", {"default": 1, "min": 1, "max": 99, "step": 1}),
|
| |
|
| |
|
| |
|
| | "hybrid_tolerance_px": ("INT", {"default": 8, "min": 0, "max": 99999, "step": 1}),
|
| | },
|
| | "optional": {
|
| | "mask": ("MASK", {}),
|
| | }
|
| | }
|
| |
|
| | RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
| | RETURN_NAMES = ("images", "mask", "shifts_x")
|
| | FUNCTION = "stabilize"
|
| | CATEGORY = "image/sprite"
|
| | SEARCH_ALIASES = ["wiggle stabilize", "sprite stabilize", "head stabilize", "animation stabilize", "sprite jitter fix"]
|
| |
|
| |
|
| |
|
| | def _get_alpha(self, images: torch.Tensor, mask: torch.Tensor | None, mask_is_inverted: bool) -> torch.Tensor:
|
| | """
|
| | Returns alpha in [0..1], shape [B,H,W].
|
| | """
|
| | if images.dim() != 4:
|
| | raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
|
| | B, H, W, C = images.shape
|
| |
|
| | if C >= 4:
|
| | return images[..., 3]
|
| |
|
| | if mask is None:
|
| | raise ValueError("Need RGBA images (C>=4) OR provide a MASK input.")
|
| |
|
| | if mask.dim() == 2:
|
| | mask = mask.unsqueeze(0)
|
| | if mask.dim() != 3:
|
| | raise ValueError(f"mask must have shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
|
| |
|
| | if mask.shape[1] != H or mask.shape[2] != W:
|
| | raise ValueError(f"mask H/W must match images; mask={tuple(mask.shape)} images={tuple(images.shape)}")
|
| |
|
| | if mask.shape[0] == 1 and B > 1:
|
| | mask = mask.repeat(B, 1, 1)
|
| | if mask.shape[0] != B:
|
| | raise ValueError(f"mask batch must match images batch; mask B={mask.shape[0]} images B={B}")
|
| |
|
| | alpha = 1.0 - mask if mask_is_inverted else mask
|
| | return alpha
|
| |
|
| | def _clamp_y(self, H: int, y_min: int, y_max: int) -> tuple[int, int]:
|
| | y0 = int(y_min)
|
| | y1 = int(y_max)
|
| | if y1 < y0:
|
| | y0, y1 = y1, y0
|
| | y0 = max(0, min(H - 1, y0))
|
| | y1 = max(0, min(H - 1, y1))
|
| | return y0, y1
|
| |
|
| | def _bbox_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
|
| | """
|
| | alpha_hw: [H,W]
|
| | Returns center X using leftmost/rightmost visible columns, or None if empty.
|
| | """
|
| |
|
| | visible = alpha_hw > thr
|
| | cols = visible.any(dim=0)
|
| | if not bool(cols.any()):
|
| | return None
|
| | W = alpha_hw.shape[1]
|
| | left = int(torch.argmax(cols.float()).item())
|
| | right = int((W - 1) - torch.argmax(torch.flip(cols, dims=[0]).float()).item())
|
| | return (left + right) / 2.0
|
| |
|
| | def _com_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
|
| | """
|
| | alpha_hw: [H,W]
|
| | Alpha-weighted center-of-mass of X within visible area, or None if empty.
|
| | """
|
| | W = alpha_hw.shape[1]
|
| | weights = alpha_hw
|
| | if thr > 0:
|
| | weights = weights * (weights > thr)
|
| |
|
| | profile = weights.sum(dim=0)
|
| | total = float(profile.sum().item())
|
| | if total <= 0.0:
|
| | return None
|
| |
|
| | x = torch.arange(W, device=alpha_hw.device, dtype=profile.dtype)
|
| | center = float((profile * x).sum().item() / total)
|
| | return center
|
| |
|
| | def _phase_corr_shift_x(self, alpha_hw: torch.Tensor, ref_profile: torch.Tensor, thr: float) -> int | None:
|
| | """
|
| | Estimate integer shift to APPLY to current frame (X) so it matches reference.
|
| | Uses 1D phase correlation on horizontal alpha profile.
|
| | Returns shift_x (int), or None if empty.
|
| | """
|
| | weights = alpha_hw
|
| | if thr > 0:
|
| | weights = weights * (weights > thr)
|
| |
|
| | prof = weights.sum(dim=0).float()
|
| | if float(prof.sum().item()) <= 0.0:
|
| | return None
|
| |
|
| |
|
| | prof = prof - prof.mean()
|
| | ref = ref_profile
|
| |
|
| |
|
| | F = torch.fft.rfft(prof)
|
| | R = torch.fft.rfft(ref)
|
| | cps = F * torch.conj(R)
|
| | cps = cps / (torch.abs(cps) + 1e-9)
|
| | corr = torch.fft.irfft(cps, n=prof.numel())
|
| | peak = int(torch.argmax(corr).item())
|
| |
|
| | W = prof.numel()
|
| | lag = peak if peak <= W // 2 else peak - W
|
| | shift_x = -lag
|
| | return int(shift_x)
|
| |
|
| | def _shift_frame_x(self, img_hwc: torch.Tensor, shift_x: int) -> torch.Tensor:
|
| | """
|
| | img_hwc: [H,W,C]
|
| | shift_x: int (positive -> move right)
|
| | """
|
| | H, W, C = img_hwc.shape
|
| | out = torch.zeros_like(img_hwc)
|
| | if shift_x == 0:
|
| | return img_hwc
|
| | if abs(shift_x) >= W:
|
| | return out
|
| |
|
| | if shift_x > 0:
|
| | out[:, shift_x:, :] = img_hwc[:, : W - shift_x, :]
|
| | else:
|
| | sx = -shift_x
|
| | out[:, : W - sx, :] = img_hwc[:, sx:, :]
|
| | return out
|
| |
|
| | def _shift_mask_x(self, m_hw: torch.Tensor, shift_x: int, fill_val: float) -> torch.Tensor:
|
| | """
|
| | m_hw: [H,W]
|
| | """
|
| | H, W = m_hw.shape
|
| | out = torch.full_like(m_hw, fill_val)
|
| | if shift_x == 0:
|
| | return m_hw
|
| | if abs(shift_x) >= W:
|
| | return out
|
| | if shift_x > 0:
|
| | out[:, shift_x:] = m_hw[:, : W - shift_x]
|
| | else:
|
| | sx = -shift_x
|
| | out[:, : W - sx] = m_hw[:, sx:]
|
| | return out
|
| |
|
| | def _median_smooth(self, shifts: list[int], window: int) -> list[int]:
|
| | """
|
| | Median filter over shifts with odd window size. Keeps shifts[0] unchanged.
|
| | """
|
| | if window <= 1 or len(shifts) <= 2:
|
| | return shifts
|
| | w = int(window)
|
| | if w % 2 == 0:
|
| | w += 1
|
| | r = w // 2
|
| | out = shifts[:]
|
| | out[0] = shifts[0]
|
| | n = len(shifts)
|
| | for i in range(1, n):
|
| | lo = max(1, i - r)
|
| | hi = min(n, i + r + 1)
|
| | vals = sorted(shifts[lo:hi])
|
| | out[i] = vals[len(vals) // 2]
|
| | return out
|
| |
|
| |
|
| |
|
| | def stabilize(
|
| | self,
|
| | images: torch.Tensor,
|
| | y_min: int = 210,
|
| | y_max: int = 332,
|
| | alpha_threshold_8bit: int = 5,
|
| | method: str = "alpha_com",
|
| | mask_is_inverted: bool = True,
|
| | max_abs_shift: int = 0,
|
| | temporal_median: int = 1,
|
| | hybrid_tolerance_px: int = 8,
|
| | mask: torch.Tensor | None = None,
|
| | ):
|
| | if not torch.is_tensor(images):
|
| | raise TypeError("images must be a torch.Tensor")
|
| | if images.dim() != 4:
|
| | raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
|
| |
|
| | B, H, W, C = images.shape
|
| | if B < 1:
|
| | raise ValueError("images batch is empty")
|
| |
|
| | alpha = self._get_alpha(images, mask, mask_is_inverted)
|
| | y0, y1 = self._clamp_y(H, y_min, y_max)
|
| | thr = float(alpha_threshold_8bit) / 255.0
|
| |
|
| | roi_alpha = alpha[:, y0:y1 + 1, :]
|
| |
|
| |
|
| | ref_roi = roi_alpha[0]
|
| |
|
| |
|
| | ref_center_bbox = None
|
| | ref_center_com = None
|
| | ref_profile = None
|
| |
|
| | if method in ("bbox_center", "hybrid"):
|
| | ref_center_bbox = self._bbox_center_x(ref_roi, thr)
|
| | if method in ("alpha_com", "hybrid"):
|
| | ref_center_com = self._com_center_x(ref_roi, thr)
|
| | if method in ("profile_corr", "hybrid"):
|
| |
|
| | w = ref_roi
|
| | if thr > 0:
|
| | w = w * (w > thr)
|
| | ref_profile = w.sum(dim=0).float()
|
| | ref_profile = ref_profile - ref_profile.mean()
|
| |
|
| |
|
| | if ref_center_bbox is None and ref_center_com is None and ref_profile is None:
|
| |
|
| | out_mask = None
|
| | if mask is not None:
|
| | out_mask = mask if mask.dim() == 3 else mask.unsqueeze(0)
|
| | elif C >= 4:
|
| | a = images[..., 3]
|
| | out_mask = (1.0 - a) if mask_is_inverted else a
|
| | else:
|
| | fill_val = 1.0 if mask_is_inverted else 0.0
|
| | out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
|
| |
|
| | return (images, out_mask, "[0]" if B == 1 else str([0] * B))
|
| |
|
| |
|
| |
|
| | if ref_center_com is not None:
|
| | ref_center = ref_center_com
|
| | elif ref_center_bbox is not None:
|
| | ref_center = ref_center_bbox
|
| | else:
|
| | ref_center = W / 2.0
|
| |
|
| | shifts = [0] * B
|
| | shifts[0] = 0
|
| |
|
| | for i in range(1, B):
|
| | a_hw = roi_alpha[i]
|
| |
|
| | shift_i = 0
|
| |
|
| | if method == "bbox_center":
|
| | c = self._bbox_center_x(a_hw, thr)
|
| | if c is None:
|
| | shift_i = 0
|
| | else:
|
| | shift_i = int(round(ref_center - c))
|
| |
|
| | elif method == "alpha_com":
|
| | c = self._com_center_x(a_hw, thr)
|
| | if c is None:
|
| | shift_i = 0
|
| | else:
|
| | shift_i = int(round(ref_center - c))
|
| |
|
| | elif method == "profile_corr":
|
| | s = self._phase_corr_shift_x(a_hw, ref_profile, thr)
|
| | shift_i = 0 if s is None else int(s)
|
| |
|
| | elif method == "hybrid":
|
| |
|
| | s_corr = self._phase_corr_shift_x(a_hw, ref_profile, thr) if ref_profile is not None else None
|
| |
|
| |
|
| | c = self._com_center_x(a_hw, thr)
|
| | s_com = None if c is None else int(round(ref_center - c))
|
| |
|
| | if s_corr is None and s_com is None:
|
| | shift_i = 0
|
| | elif s_corr is None:
|
| | shift_i = int(s_com)
|
| | elif s_com is None:
|
| | shift_i = int(s_corr)
|
| | else:
|
| | if abs(int(s_corr) - int(s_com)) > int(hybrid_tolerance_px):
|
| | shift_i = int(s_com)
|
| | else:
|
| | shift_i = int(s_corr)
|
| |
|
| | else:
|
| | raise ValueError(f"Unknown method: {method}")
|
| |
|
| |
|
| | if max_abs_shift and max_abs_shift > 0:
|
| | shift_i = int(max(-max_abs_shift, min(max_abs_shift, shift_i)))
|
| |
|
| | shifts[i] = shift_i
|
| |
|
| |
|
| | shifts = self._median_smooth(shifts, int(temporal_median))
|
| |
|
| |
|
| | out_images = torch.zeros_like(images)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | out_mask = None
|
| | in_mask_bhw = None
|
| | if mask is not None:
|
| | in_mask_bhw = mask
|
| | if in_mask_bhw.dim() == 2:
|
| | in_mask_bhw = in_mask_bhw.unsqueeze(0)
|
| | if in_mask_bhw.shape[0] == 1 and B > 1:
|
| | in_mask_bhw = in_mask_bhw.repeat(B, 1, 1)
|
| |
|
| | fill_val = 1.0 if mask_is_inverted else 0.0
|
| | out_mask = torch.full_like(in_mask_bhw, fill_val)
|
| |
|
| | for i in range(B):
|
| | sx = int(shifts[i])
|
| | out_images[i] = self._shift_frame_x(images[i], sx)
|
| |
|
| | if out_mask is not None and in_mask_bhw is not None:
|
| | fill_val = 1.0 if mask_is_inverted else 0.0
|
| | out_mask[i] = self._shift_mask_x(in_mask_bhw[i], sx, fill_val)
|
| |
|
| | if out_mask is None:
|
| | if out_images.shape[-1] >= 4:
|
| | a = out_images[..., 3]
|
| | out_mask = (1.0 - a) if mask_is_inverted else a
|
| | else:
|
| | fill_val = 1.0 if mask_is_inverted else 0.0
|
| | out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
|
| |
|
| | shifts_str = str(shifts)
|
| | return (out_images, out_mask, shifts_str)
|
| |
|
| |
|
# Node registration tables (presumably read by the ComfyUI custom-node
# loader -- confirm against the host application).

# Internal node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SpriteHeadStabilizeX": SpriteHeadStabilizeX,
}

# Internal node identifier -> human-readable name shown to the user.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SpriteHeadStabilizeX": "Sprite Head Stabilize X (Batch)",
}