saliacoel commited on
Commit
fa43fe0
·
verified ·
1 Parent(s): a7b7bed

Delete Get_Correct_Batch_Img.py

Browse files
Files changed (1) hide show
  1. Get_Correct_Batch_Img.py +0 -171
Get_Correct_Batch_Img.py DELETED
@@ -1,171 +0,0 @@
1
- import torch
2
-
3
-
4
- class Get_Correct_Batch_Img:
5
- """
6
- Given a batch of RGBA images, selects:
7
- - the sprite with the widest visible span along a given Y row (max_img)
8
- - the sprite with the thinnest visible span along that same row (min_img)
9
- - the sprite whose width is closest to the midpoint between min/max widths (avg_img)
10
-
11
- Visibility is determined from the alpha channel (A > 0).
12
- Only images within [start_index, end_index] (inclusive) are considered.
13
- """
14
-
15
- # Where this node appears in the right-click menu:
16
- CATEGORY = "image/batch"
17
-
18
- @classmethod
19
- def INPUT_TYPES(s):
20
- return {
21
- "required": {
22
- # RGBA image batch: torch.Tensor [B, H, W, 4]
23
- "images": ("IMAGE",),
24
-
25
- # Sub-batch start index (inclusive, 0-based)
26
- "start_index": (
27
- "INT",
28
- {
29
- "default": 0,
30
- "min": 0,
31
- "max": 2_147_483_647,
32
- "step": 1,
33
- },
34
- ),
35
-
36
- # Sub-batch end index (inclusive, 0-based)
37
- "end_index": (
38
- "INT",
39
- {
40
- "default": 0,
41
- "min": 0,
42
- "max": 2_147_483_647,
43
- "step": 1,
44
- },
45
- ),
46
-
47
- # Y coordinate (row) used for the horizontal scan
48
- "y_coord": (
49
- "INT",
50
- {
51
- "default": 0,
52
- "min": 0,
53
- "max": 2_147_483_647,
54
- "step": 1,
55
- },
56
- ),
57
- }
58
- }
59
-
60
- # Three RGBA images out now
61
- RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
62
- RETURN_NAMES = ("max_img", "min_img", "avg_img")
63
- FUNCTION = "select"
64
-
65
- def select(self, images, start_index, end_index, y_coord):
66
- # Basic sanity checks
67
- if not isinstance(images, torch.Tensor):
68
- raise TypeError(f"Expected IMAGE tensor, got {type(images)}")
69
-
70
- if images.ndim != 4:
71
- raise ValueError(
72
- f"Expected IMAGE of shape [B,H,W,C], got {tuple(images.shape)}"
73
- )
74
-
75
- batch_size, height, width, channels = images.shape
76
-
77
- if channels != 4:
78
- raise ValueError(
79
- f"Expected RGBA image with 4 channels, got {channels}. "
80
- "Make sure your input batch is RGBA (not RGB)."
81
- )
82
-
83
- if batch_size == 0:
84
- raise ValueError("Empty image batch passed to Get_Correct_Batch_Img.")
85
-
86
- # Clamp and normalize indices
87
- start = max(0, min(int(start_index), batch_size - 1))
88
- end = max(0, min(int(end_index), batch_size - 1))
89
- if start > end:
90
- start, end = end, start # swap so start <= end
91
-
92
- # Clamp Y coordinate into image bounds
93
- y = max(0, min(int(y_coord), height - 1))
94
-
95
- # Track widest and thinnest sprite
96
- max_width = None
97
- min_width = None
98
- max_idx = start
99
- min_idx = start
100
-
101
- # For AVG: store (index, width_px) for all valid sprites
102
- widths = []
103
-
104
- # Small alpha threshold; alpha > 0 is "visible"
105
- alpha_threshold = 0.0
106
- any_visible = False
107
-
108
- # Loop over the requested sub-batch only
109
- for i in range(start, end + 1):
110
- # row_alpha shape: [W]
111
- row_alpha = images[i, y, :, 3]
112
- visible = row_alpha > alpha_threshold
113
-
114
- if not torch.any(visible):
115
- # No visible pixels on this row for this image; skip it
116
- continue
117
-
118
- any_visible = True
119
-
120
- # Indices of visible pixels along X
121
- visible_indices = torch.nonzero(visible, as_tuple=False).squeeze(1)
122
- left_x = int(visible_indices[0])
123
- right_x = int(visible_indices[-1])
124
- width_px = right_x - left_x + 1 # inclusive distance
125
-
126
- widths.append((i, width_px))
127
-
128
- # Update max width (widest sprite)
129
- if max_width is None or width_px > max_width:
130
- max_width = width_px
131
- max_idx = i
132
-
133
- # Update min width (thinnest sprite)
134
- if min_width is None or width_px < min_width:
135
- min_width = width_px
136
- min_idx = i
137
-
138
- # If nothing had visible pixels on that Y, just return the first image
139
- # in the sub-batch as all three outputs (so the node never crashes).
140
- if not any_visible:
141
- base_img = images[start].unsqueeze(0)
142
- return (base_img, base_img, base_img)
143
-
144
- # Compute midpoint between MIN and MAX widths
145
- center_width = (min_width + max_width) / 2.0
146
-
147
- # Find sprite whose width is closest to this center_width
148
- avg_idx = max_idx # default
149
- closest_diff = None
150
- for idx, w in widths:
151
- diff = abs(w - center_width)
152
- if closest_diff is None or diff < closest_diff:
153
- closest_diff = diff
154
- avg_idx = idx
155
-
156
- # Extract chosen sprites as batch size 1 (B=1, H, W, C)
157
- max_img = images[max_idx].unsqueeze(0)
158
- min_img = images[min_idx].unsqueeze(0)
159
- avg_img = images[avg_idx].unsqueeze(0)
160
-
161
- return (max_img, min_img, avg_img)
162
-
163
-
164
- # Register node with ComfyUI
165
- NODE_CLASS_MAPPINGS = {
166
- "Get_Correct_Batch_Img": Get_Correct_Batch_Img,
167
- }
168
-
169
- NODE_DISPLAY_NAME_MAPPINGS = {
170
- "Get_Correct_Batch_Img": "Get_Correct_Batch_Img (Salia)",
171
- }