| from __future__ import annotations |
|
|
| import heapq |
| from typing import Dict, List, Tuple |
|
|
| import torch |
|
|
|
|
class BatchEvenMotionPruner:
    """
    Repeatedly remove the most redundant interior frame from an IMAGE batch
    until the requested batch size is reached.

    Redundancy score for an interior frame i, computed against its *current*
    neighbours (i.e. after earlier removals have been applied):
        mean_abs_diff(frame[i], frame[left_neighbor]) +
        mean_abs_diff(frame[i], frame[right_neighbor])

    The frame with the LOWEST score is removed first; equal scores are broken
    in favour of the lower batch index. The first and last frames are never
    removed, so a ``target_count`` below 2 is clamped to 2.
    """

    CATEGORY = "image/batch"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "prune"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: an IMAGE batch and an INT target frame count
        (range 1-4096, default 16)."""
        return {
            "required": {
                "images": ("IMAGE", {}),
                "target_count": (
                    "INT",
                    {
                        "default": 16,
                        "min": 1,
                        "max": 4096,
                        "step": 1,
                    },
                ),
            }
        }

    @staticmethod
    def _validate_images(images: torch.Tensor) -> torch.Tensor:
        """
        Ensure *images* is a tensor and return it with a leading batch axis.

        A 3-D tensor is treated as a single frame and gets a batch dimension of
        size 1 prepended. Only the rank is checked; the [B,H,W,C] layout named
        in the error message is assumed, not verified.

        Raises:
            TypeError: if *images* is not a ``torch.Tensor``.
            ValueError: if *images* is neither 3-D nor 4-D.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("Expected 'images' to be a torch.Tensor.")

        if images.ndim == 3:
            images = images.unsqueeze(0)
        elif images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(images.shape)}."
            )

        return images

    @staticmethod
    def _pair_key(a: int, b: int) -> Tuple[int, int]:
        """Return an order-independent cache key for the frame pair (a, b)."""
        return (a, b) if a < b else (b, a)

    def _pair_difference(
        self,
        images: torch.Tensor,
        left_idx: int,
        right_idx: int,
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Return the mean absolute difference between two frames.

        Frames are upcast with ``.float()`` before subtracting so low-precision
        inputs do not lose accuracy. Results are memoised in *cache*, keyed by
        the unordered index pair.
        """
        key = self._pair_key(left_idx, right_idx)
        cached = cache.get(key)
        if cached is not None:
            return cached

        # Scores are only used as plain Python floats, so skip autograd tracking.
        with torch.no_grad():
            left = images[left_idx].float()
            right = images[right_idx].float()
            value = torch.mean(torch.abs(left - right)).item()

        cache[key] = value
        return value

    def _candidate_score(
        self,
        images: torch.Tensor,
        idx: int,
        prev_idx: List[int],
        next_idx: List[int],
        cache: Dict[Tuple[int, int], float],
    ) -> float:
        """
        Score frame *idx* as the sum of its differences to its current left and
        right neighbours (taken from the linked list *prev_idx*/*next_idx*).

        Raises:
            ValueError: if *idx* currently has no neighbour on either side.
        """
        left = prev_idx[idx]
        right = next_idx[idx]
        if left == -1 or right == -1:
            raise ValueError("Endpoints must not be scored for removal.")

        return (
            self._pair_difference(images, left, idx, cache)
            + self._pair_difference(images, idx, right, cache)
        )

    def prune(self, images: torch.Tensor, target_count: int):
        """
        Greedily drop the lowest-scoring interior frames.

        Args:
            images: IMAGE batch; a 3-D tensor is treated as one frame.
            target_count: desired number of output frames. Values below 2 are
                raised to 2 because both endpoints are always kept.

        Returns:
            A 1-tuple holding the pruned batch, with surviving frames in their
            original order. When nothing has to be removed, the (possibly
            unsqueezed) input tensor itself is returned.
        """
        images = self._validate_images(images)

        batch_size = int(images.shape[0])
        # The first and last frames are never removed, so at least two frames
        # survive. Batches of <= 1 frame always satisfy the check below and are
        # returned unchanged.
        desired_count = max(int(target_count), 2)

        if desired_count >= batch_size:
            return (images,)

        # Doubly linked list over surviving frames; -1 means "no neighbour".
        prev_idx = list(range(-1, batch_size - 1))
        next_idx = list(range(1, batch_size)) + [-1]
        alive = [True] * batch_size
        # Bumped on every (re)score so outdated heap entries can be skipped
        # lazily instead of being removed from the heap.
        candidate_version = [0] * batch_size
        pair_cache: Dict[Tuple[int, int], float] = {}
        heap: List[Tuple[float, int, int]] = []

        def push_candidate(i: int) -> None:
            # Endpoints and removed frames are never candidates.
            if i <= 0 or i >= batch_size - 1:
                return
            if not alive[i]:
                return
            if prev_idx[i] == -1 or next_idx[i] == -1:
                return

            candidate_version[i] += 1
            score = self._candidate_score(images, i, prev_idx, next_idx, pair_cache)
            # (score, index) ordering makes ties resolve to the lower index.
            heapq.heappush(heap, (score, i, candidate_version[i]))

        for i in range(1, batch_size - 1):
            push_candidate(i)

        remaining = batch_size

        while remaining > desired_count and heap:
            _score, idx, version = heapq.heappop(heap)

            # Skip entries for frames already removed or rescored since push.
            if not alive[idx]:
                continue
            if candidate_version[idx] != version:
                continue
            if prev_idx[idx] == -1 or next_idx[idx] == -1:
                continue

            left = prev_idx[idx]
            right = next_idx[idx]

            alive[idx] = False
            remaining -= 1

            # Unlink idx so its neighbours become adjacent.
            next_idx[left] = right
            prev_idx[right] = left
            prev_idx[idx] = -1
            next_idx[idx] = -1

            # Only the two neighbours' scores depend on idx; rescore them.
            push_candidate(left)
            push_candidate(right)

        keep_indices = [i for i, is_alive in enumerate(alive) if is_alive]
        keep_tensor = torch.tensor(keep_indices, device=images.device, dtype=torch.long)
        output = images.index_select(0, keep_tensor)
        return (output,)
|
|
|
|
# Node registration tables. The names follow the ComfyUI custom-node
# convention and are presumably read by its loader — confirm.

# node id -> implementing class
NODE_CLASS_MAPPINGS = dict(
    BatchEvenMotionPruner=BatchEvenMotionPruner,
)

# node id -> human-readable display name
NODE_DISPLAY_NAME_MAPPINGS = dict(
    BatchEvenMotionPruner="Batch Even Motion Pruner",
)
|
|