# MyCustomNodes / Custom_Batch_Output.py
# saliacoel's picture
# Upload Custom_Batch_Output.py
# a336c71 verified
# Save as: ComfyUI/custom_nodes/Custom_Batch_Output.py
# Restart ComfyUI after saving so the node is picked up.
import torch
class Custom_Batch_Output:
    """Split an IMAGE batch into two fixed sub-batches.

    Input:
        images: IMAGE batch, typically a torch.Tensor shaped [B, H, W, C].
            A single 3-D image [H, W, C] is treated as a batch of one.

    Outputs:
        Batch_Up: [ID 7] + [IDs 9..25] + [IDs 27..31] + [IDs 33..36]
            (27 images).
        Rife_x3: [ID 4] + [ID 37] (2-image batch).

    Indexing is 0-based and ranges are inclusive (e.g. 9..25 includes both
    9 and 25).

    Safety behavior:
        If the input is not a torch.Tensor, is not 4-D after normalization,
        or the batch is too small (the largest index used is 37, so B >= 38
        is required), the node returns the input batch for BOTH outputs
        instead of raising.
    """

    CATEGORY = "image/batch"
    FUNCTION = "make_special_batches"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("Batch_Up", "Rife_x3")

    # 0-based frame indices gathered into each output. Python ranges are
    # half-open, hence the +1 on every upper bound.
    _BATCH_UP_INDICES = (
        7,
        *range(9, 26),   # 9..25
        *range(27, 32),  # 27..31
        *range(33, 37),  # 33..36
    )
    _RIFE_X3_INDICES = (4, 37)

    # Smallest batch size for which every index above is valid. Derived from
    # the index tables so it cannot drift out of sync with them.
    _MIN_BATCH_SIZE = max(_BATCH_UP_INDICES + _RIFE_X3_INDICES) + 1

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: an IMAGE batch."""
        return {"required": {"images": ("IMAGE",)}}

    @staticmethod
    def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
        """Return *images* as a batch, promoting a 3-D tensor to 4-D.

        A 3-D tensor [H, W, C] becomes [1, H, W, C]; any other rank is
        returned unchanged.
        """
        if images.dim() == 3:
            return images.unsqueeze(0)
        return images

    def make_special_batches(self, images):
        """Build the Batch_Up and Rife_x3 sub-batches from *images*.

        Args:
            images: IMAGE batch; expected to be a torch.Tensor shaped
                [B, H, W, C] (or a single [H, W, C] image).

        Returns:
            A 2-tuple ``(batch_up, rife_x3)``. On invalid or too-small input
            both elements are the (batch-normalized) input itself.
        """
        # Non-tensor input: pass it through untouched for both outputs.
        if not isinstance(images, torch.Tensor):
            return (images, images)

        images = self._normalize_to_batch(images)

        # Wrong rank or not enough frames: fall back to the input batch.
        if images.dim() != 4 or int(images.shape[0]) < self._MIN_BATCH_SIZE:
            return (images, images)

        # Index tensors must live on the same device as the images.
        device = images.device
        idx_up = torch.tensor(
            self._BATCH_UP_INDICES, dtype=torch.long, device=device
        )
        idx_rife = torch.tensor(
            self._RIFE_X3_INDICES, dtype=torch.long, device=device
        )

        # index_select already returns a new tensor that does not share
        # storage with its input, so no additional clone() is required.
        batch_up = torch.index_select(images, 0, idx_up)
        rife_x3 = torch.index_select(images, 0, idx_rife)
        return (batch_up, rife_x3)
# Node registration tables: the first maps the node identifier to its
# implementing class, the second maps the same identifier to the label
# shown for the node.
NODE_CLASS_MAPPINGS = {"Custom_Batch_Output": Custom_Batch_Output}
NODE_DISPLAY_NAME_MAPPINGS = {"Custom_Batch_Output": "Custom_Batch_Output"}