|
|
|
|
|
|
|
|
|
|
| import torch
|
|
|
|
|
def _clamp_index(index: int, batch_size: int) -> int:
    """Limit ``index`` to the valid positions ``0 .. batch_size - 1``.

    An out-of-range ``index`` is not treated as an error: a notice is
    printed to stdout and the nearest valid position is returned.

    Raises:
        ValueError: if ``batch_size`` is zero or negative (nothing to index).
    """
    if batch_size <= 0:
        raise ValueError("Input batch is empty (batch_size <= 0).")

    last = batch_size - 1
    if index < 0 or batch_size <= index:
        print(
            f"[BatchIndexTools] index {index} out of range for batch_size {batch_size}; "
            "clamping to valid range."
        )
    # In-range values pass through unchanged; others snap to 0 or `last`.
    return max(0, min(index, last))
|
|
|
|
|
class BatchGetImageAtIndex:
    """Pick one image out of an IMAGE batch.

    Inputs:
        images: IMAGE batch tensor, expected as [B, H, W, C].
        index:  zero-based position of the wanted image (0 is the first).

    Output:
        image: the selected item, still carrying a leading batch axis of size 1.

    An index outside [0, B-1] is clamped to the nearest valid position
    instead of raising.
    """

    CATEGORY = "Batch/Index"
    FUNCTION = "get"
    RETURN_NAMES = ("image",)
    RETURN_TYPES = ("IMAGE",)

    @classmethod
    def INPUT_TYPES(cls):
        # The index widget only offers non-negative values.
        index_spec = ("INT", {"default": 0, "min": 0, "max": 10**9})
        return {"required": {"images": ("IMAGE",), "index": index_spec}}

    def get(self, images, index):
        """Return a 1-tuple holding the (clamped) ``index``-th image as a [1, H, W, C] slice.

        Raises:
            TypeError: if ``images`` is not a tensor.
            ValueError: if ``images`` is not 4-D, or the batch is empty.
        """
        if not torch.is_tensor(images):
            raise TypeError("Expected 'images' to be a torch Tensor (ComfyUI IMAGE type).")
        rank = images.ndim
        if rank != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={rank}.")

        # The shared helper raises on an empty batch and clamps out-of-range values.
        start = _clamp_index(int(index), images.shape[0])
        # Slicing (rather than plain indexing) keeps the leading batch dimension.
        picked = images[start:start + 1]
        return (picked,)
|
|
|
|
|
class BatchReplaceImageAtIndex:
    """
    Node 2:
    - Takes an IMAGE batch, an integer index, and a single IMAGE
    - Replaces the batch item at that index with the provided image
    - Outputs the modified batch as a new tensor (the input batch is cloned,
      not modified in place)
    Notes:
    - Index is zero-based (0 is the first image).
    - If index is out of range, it is clamped to the nearest valid index.
    - The replacement image must have the same H/W/C as the batch images.
    - If 'image' is a batch, only the first image is used; an empty 'image'
      batch is rejected with a ValueError.
    - The replacement is converted to the batch's device and dtype.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "replace"
    CATEGORY = "Batch/Index"

    def replace(self, images, index, image):
        """Return a copy of ``images`` with the item at ``index`` overwritten.

        Args:
            images: IMAGE batch tensor, expected shape [B, H, W, C].
            index: zero-based target position; clamped into [0, B-1].
            image: IMAGE tensor, expected shape [N, H, W, C] with N >= 1;
                only its first item is used.

        Returns:
            A 1-tuple holding the modified clone of ``images``.

        Raises:
            TypeError: if either input is not a tensor.
            ValueError: on wrong rank, an empty ``image`` batch, mismatched
                [H, W, C], or an empty ``images`` batch.
        """
        if not torch.is_tensor(images) or not torch.is_tensor(image):
            raise TypeError("Expected 'images' and 'image' to be torch Tensors (ComfyUI IMAGE type).")
        if images.ndim != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={images.ndim}.")
        if image.ndim != 4:
            raise ValueError(f"Expected 'image' with shape [B,H,W,C], got ndim={image.ndim}.")

        # An empty replacement batch would pass the [H,W,C] check below and then
        # fail with an opaque IndexError on image[0]; reject it explicitly.
        if image.shape[0] == 0:
            raise ValueError("Replacement 'image' batch is empty; expected at least one image.")

        batch_hwc = tuple(images.shape[1:])
        replacement_hwc = tuple(image.shape[1:])
        if replacement_hwc != batch_hwc:
            raise ValueError(
                "Replacement image must match batch image shape [H,W,C]. "
                f"Batch has [H,W,C]={batch_hwc}, "
                f"replacement has [H,W,C]={replacement_hwc}."
            )

        # Clamp only after every input is validated, so a rejected call does not
        # also print a clamping notice. _clamp_index raises if `images` is empty.
        idx = _clamp_index(int(index), images.shape[0])

        # Clone so the caller's tensor (possibly shared with other nodes) is untouched.
        out = images.clone()
        out[idx] = image[0].to(device=out.device, dtype=out.dtype)
        return (out,)
|
|
|
|
|
# ComfyUI custom-node registration: node type id -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    BatchGetImageAtIndex=BatchGetImageAtIndex,
    BatchReplaceImageAtIndex=BatchReplaceImageAtIndex,
)

# Node type id -> human-readable title (presumably what the ComfyUI node menu shows).
NODE_DISPLAY_NAME_MAPPINGS = dict(
    BatchGetImageAtIndex="Batch: Get Image @ Index",
    BatchReplaceImageAtIndex="Batch: Replace Image @ Index",
)
|
|
|