| |
|
| |
|
| |
|
| | import torch
|
| |
|
| |
|
class Get_Batch_Range_Start_To_End:
    """
    ComfyUI node that extracts a contiguous, inclusive sub-range of an IMAGE batch.

    Inputs:
      - start_id (INT): 0-based index of the first image to keep.
      - end_id (INT): 0-based index of the last image to keep (inclusive).
      - images (IMAGE batch, typically torch.Tensor [B, H, W, C]). A 3-dim
        tensor is treated as a batch of one.

    Outputs:
      - images (IMAGE batch): images[start_id .. end_id] on success.
      - status (STRING): "ok" or an "error: ..." message.
      - count (INT): number of images in the *input* batch.

    Behavior:
      - Returns images from start_id to end_id (inclusive) as a new tensor.
      - If invalid / impossible (non-tensor input, wrong number of dims,
        empty batch, negative index, start > end, out of range), returns the
        original input object unchanged, plus an error message.
    """

    CATEGORY = "image/batch"
    FUNCTION = "slice_batch"

    RETURN_TYPES = ("IMAGE", "STRING", "INT")
    RETURN_NAMES = ("images", "status", "count")

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required inputs for ComfyUI."""
        return {
            "required": {
                "start_id": ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1}),
                "end_id": ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1}),
                "images": ("IMAGE",),
            }
        }

    def slice_batch(self, start_id, end_id, images):
        """Return ``(images_out, status, count)`` for the inclusive range.

        Args:
            start_id: first index to keep (0-based).
            end_id: last index to keep (0-based, inclusive).
            images: the input batch; a 3-dim tensor is promoted to a batch of one.

        Returns:
            On success: (a copy of images[start_id:end_id + 1], "ok", B).
            On failure: (the ``images`` argument exactly as passed in,
            an "error: ..." message, B or 0).
        """
        if not isinstance(images, torch.Tensor):
            return (images, "error: images is not a torch.Tensor", 0)

        original = images
        if images.dim() == 3:
            # Promote a single image to a batch of one for the slicing below.
            images = images.unsqueeze(0)
        elif images.dim() != 4:
            count = int(images.shape[0]) if images.dim() > 0 else 0
            return (original, f"error: expected IMAGE with 3 or 4 dims, got {tuple(images.shape)}", count)

        b = int(images.shape[0])

        # Every error path below hands back `original`, never the unsqueezed
        # view, so "unchanged" in the docstring holds for 3-dim inputs too.
        if b <= 0:
            return (original, "error: empty batch (B=0)", 0)

        # Check signs before ordering so e.g. (3, -1) reports the negative
        # index rather than a misleading "start_id > end_id".
        if start_id < 0 or end_id < 0:
            return (original, f"error: negative index not allowed (start_id={start_id}, end_id={end_id})", b)

        if start_id > end_id:
            return (original, f"error: start_id > end_id ({start_id} > {end_id})", b)

        if start_id >= b or end_id >= b:
            return (
                original,
                f"error: out of range (start_id={start_id}, end_id={end_id}, batch_size={b})",
                b,
            )

        # clone() so downstream in-place edits cannot alias the input batch.
        sliced = images[start_id : end_id + 1].clone()
        return (sliced, "ok", b)
|
| |
|
| |
|
# ComfyUI registry: node type key -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    Get_Batch_Range_Start_To_End=Get_Batch_Range_Start_To_End,
)
|
| |
|
# ComfyUI registry: node type key -> human-readable title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    Get_Batch_Range_Start_To_End="Get Batch from Batch (From Start ID to End ID)",
)
|
| |
|