SynLayers commited on
Commit
12e6363
·
verified ·
1 Parent(s): 26369cf

Upload dataset/scaleup_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataset/scaleup_utils.py +546 -0
dataset/scaleup_utils.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for scaling up the PrismLayersPro-blended dataset.
3
+
4
+ This module provides utilities for:
5
+ - Loading existing blended samples
6
+ - Computing non-overlapping bounding boxes
7
+ - Generating spatial-aware captions with position words
8
+ - Layer combination and compositing
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import random
14
+ from typing import Dict, List, Tuple, Optional
15
+ from PIL import Image
16
+ import numpy as np
17
+
18
+
19
+ def load_jsonl(path: str) -> List[Dict]:
20
+ """Load JSONL file and return list of dictionaries."""
21
+ items = []
22
+ with open(path, 'r', encoding='utf-8') as f:
23
+ for line in f:
24
+ line = line.strip()
25
+ if line:
26
+ items.append(json.loads(line))
27
+ return items
28
+
29
+
30
+ def save_jsonl(items: List[Dict], path: str):
31
+ """Save list of dictionaries to JSONL file."""
32
+ with open(path, 'w', encoding='utf-8') as f:
33
+ for item in items:
34
+ f.write(json.dumps(item, ensure_ascii=False) + '\n')
35
+
36
+
37
+ def load_blended_sample(sample_dir: str) -> Optional[Dict]:
38
+ """
39
+ Load a blended sample from its directory.
40
+ Returns metadata dict with loaded layer images.
41
+ """
42
+ metadata_path = os.path.join(sample_dir, 'metadata.json')
43
+ if not os.path.exists(metadata_path):
44
+ return None
45
+
46
+ with open(metadata_path, 'r', encoding='utf-8') as f:
47
+ metadata = json.load(f)
48
+
49
+ # Load base_image (background)
50
+ base_path = os.path.join(sample_dir, 'base_image.png')
51
+ if os.path.exists(base_path):
52
+ metadata['base_image'] = Image.open(base_path).convert('RGBA')
53
+ else:
54
+ metadata['base_image'] = None
55
+
56
+ # Load layer images
57
+ metadata['layer_images'] = {}
58
+ for layer in metadata.get('layers', []):
59
+ img_path = os.path.join(sample_dir, layer['image_path'])
60
+ if os.path.exists(img_path):
61
+ metadata['layer_images'][layer['layer_idx']] = Image.open(img_path).convert('RGBA')
62
+
63
+ # Store sample directory path
64
+ metadata['sample_path'] = sample_dir
65
+
66
+ return metadata
67
+
68
+
69
+ def get_blended_sample_dirs(blended_dir: str, max_samples: Optional[int] = None) -> List[str]:
70
+ """
71
+ Get list of sample directories in the blended directory.
72
+ """
73
+ sample_dirs = []
74
+ for name in sorted(os.listdir(blended_dir)):
75
+ if name.startswith('sample_') and os.path.isdir(os.path.join(blended_dir, name)):
76
+ sample_dirs.append(os.path.join(blended_dir, name))
77
+ if max_samples and len(sample_dirs) >= max_samples:
78
+ break
79
+ return sample_dirs
80
+
81
+
82
+ def compute_overlap_area(box1: List[int], box2: List[int]) -> int:
83
+ """
84
+ Calculate the overlap area between two boxes (xyxy format).
85
+ Returns 0 if no overlap.
86
+ """
87
+ x0_1, y0_1, x1_1, y1_1 = box1
88
+ x0_2, y0_2, x1_2, y1_2 = box2
89
+
90
+ # Calculate intersection
91
+ x0_i = max(x0_1, x0_2)
92
+ y0_i = max(y0_1, y0_2)
93
+ x1_i = min(x1_1, x1_2)
94
+ y1_i = min(y1_1, y1_2)
95
+
96
+ # Check if there's an intersection
97
+ if x0_i >= x1_i or y0_i >= y1_i:
98
+ return 0
99
+
100
+ return (x1_i - x0_i) * (y1_i - y0_i)
101
+
102
+
103
+ def compute_total_overlap(box: List[int], existing_boxes: List[List[int]]) -> int:
104
+ """
105
+ Calculate total overlap area between a box and all existing boxes.
106
+ """
107
+ total = 0
108
+ for eb in existing_boxes:
109
+ total += compute_overlap_area(box, eb)
110
+ return total
111
+
112
+
113
+ def get_position_description(box: List[int], canvas_size: int) -> str:
114
+ """
115
+ Get position description for a bounding box.
116
+
117
+ Based on the box center point position, returns one of:
118
+ - "On the top-left"
119
+ - "On the top-right"
120
+ - "On the bottom-left"
121
+ - "On the bottom-right"
122
+ - "In the center"
123
+ - "At the top"
124
+ - "At the bottom"
125
+ - "On the left"
126
+ - "On the right"
127
+ """
128
+ x0, y0, x1, y1 = box
129
+ center_x = (x0 + x1) / 2
130
+ center_y = (y0 + y1) / 2
131
+
132
+ # Normalize to 0-1 range
133
+ norm_x = center_x / canvas_size
134
+ norm_y = center_y / canvas_size
135
+
136
+ # Define regions (3x3 grid)
137
+ # Left: 0-0.33, Center: 0.33-0.67, Right: 0.67-1.0
138
+ # Top: 0-0.33, Middle: 0.33-0.67, Bottom: 0.67-1.0
139
+
140
+ if norm_y < 0.33:
141
+ if norm_x < 0.33:
142
+ return "On the top-left"
143
+ elif norm_x > 0.67:
144
+ return "On the top-right"
145
+ else:
146
+ return "At the top"
147
+ elif norm_y > 0.67:
148
+ if norm_x < 0.33:
149
+ return "On the bottom-left"
150
+ elif norm_x > 0.67:
151
+ return "On the bottom-right"
152
+ else:
153
+ return "At the bottom"
154
+ else:
155
+ if norm_x < 0.33:
156
+ return "On the left"
157
+ elif norm_x > 0.67:
158
+ return "On the right"
159
+ else:
160
+ return "In the center"
161
+
162
+
163
+ def build_spatial_aware_caption(layers: List[Dict], canvas_size: int, base_caption: str = "") -> str:
164
+ """
165
+ Build a spatial-aware whole caption by adding position descriptions to each layer.
166
+
167
+ Example output:
168
+ "On the top-left, a red balloon. In the center, a clown character. At the bottom, Text: hello world."
169
+
170
+ This structured format with spatial information helps diffusion models (especially Flux with T5)
171
+ better understand the position-layer correspondence.
172
+ """
173
+ parts = []
174
+
175
+ # Add base caption if provided (shortened version)
176
+ if base_caption:
177
+ # Take only the first sentence of base caption to keep it concise
178
+ first_sentence = base_caption.split('.')[0].strip()
179
+ if first_sentence:
180
+ parts.append(first_sentence + ".")
181
+
182
+ # Add layer descriptions with position
183
+ for layer in layers:
184
+ caption = layer.get('caption', '').strip()
185
+ if not caption:
186
+ continue
187
+
188
+ box = layer.get('box', [0, 0, canvas_size, canvas_size])
189
+ position = get_position_description(box, canvas_size)
190
+
191
+ # Clean up caption - remove leading "The picture/image features" etc.
192
+ caption_clean = caption
193
+ prefixes_to_remove = [
194
+ "The picture features ",
195
+ "The image features ",
196
+ "Text ",
197
+ ]
198
+ for prefix in prefixes_to_remove:
199
+ if caption_clean.startswith(prefix):
200
+ caption_clean = caption_clean[len(prefix):]
201
+ break
202
+
203
+ # Capitalize first letter
204
+ if caption_clean:
205
+ caption_clean = caption_clean[0].upper() + caption_clean[1:] if len(caption_clean) > 1 else caption_clean.upper()
206
+
207
+ # Remove trailing period if present
208
+ caption_clean = caption_clean.rstrip('.')
209
+
210
+ parts.append(f"{position}, {caption_clean}.")
211
+
212
+ return " ".join(parts)
213
+
214
+
215
+ def compute_random_box_xyxy(
216
+ canvas_size: int,
217
+ min_size_ratio: float = 0.10,
218
+ max_size_ratio: float = 0.25,
219
+ aspect_ratio_range: Tuple[float, float] = (0.5, 2.0),
220
+ center_margin: int = 16
221
+ ) -> List[int]:
222
+ """
223
+ Compute a random bounding box in xyxy format [x0, y0, x1, y1].
224
+
225
+ Args:
226
+ canvas_size: Size of the canvas (e.g., 512)
227
+ min_size_ratio: Minimum size as ratio of canvas
228
+ max_size_ratio: Maximum size as ratio of canvas
229
+ aspect_ratio_range: Range of aspect ratios (width/height)
230
+ center_margin: Margin from edges for box center (e.g., 16 means center
231
+ must be within [16, canvas_size-16] range, i.e., 480x480 area for 512 canvas)
232
+ """
233
+ min_size = int(canvas_size * min_size_ratio)
234
+ max_size = int(canvas_size * max_size_ratio)
235
+
236
+ # Random aspect ratio
237
+ aspect_ratio = random.uniform(*aspect_ratio_range)
238
+
239
+ if aspect_ratio >= 1.0:
240
+ w = random.randint(min_size, max_size)
241
+ h = int(w / aspect_ratio)
242
+ else:
243
+ h = random.randint(min_size, max_size)
244
+ w = int(h * aspect_ratio)
245
+
246
+ # Clamp to valid range
247
+ w = max(min_size, min(w, max_size))
248
+ h = max(min_size, min(h, max_size))
249
+
250
+ # Random center position within the allowed region (canvas_size - 2*margin)
251
+ # For 512 canvas with margin=16, center can be in [16, 496]
252
+ min_center = center_margin
253
+ max_center = canvas_size - center_margin
254
+
255
+ # Ensure we have valid range
256
+ if max_center <= min_center:
257
+ max_center = canvas_size - 1
258
+ min_center = 0
259
+
260
+ center_x = random.randint(min_center, max_center)
261
+ center_y = random.randint(min_center, max_center)
262
+
263
+ # Convert center to box coordinates
264
+ x0 = center_x - w // 2
265
+ y0 = center_y - h // 2
266
+ x1 = x0 + w
267
+ y1 = y0 + h
268
+
269
+ # Clamp to canvas bounds (box can extend to edges, just center is constrained)
270
+ x0 = max(0, x0)
271
+ y0 = max(0, y0)
272
+ x1 = min(canvas_size, x1)
273
+ y1 = min(canvas_size, y1)
274
+
275
+ return [x0, y0, x1, y1]
276
+
277
+
278
+ def compute_non_overlapping_box_xyxy(
279
+ canvas_size: int,
280
+ existing_boxes: List[List[int]],
281
+ min_size_ratio: float = 0.10,
282
+ max_size_ratio: float = 0.25,
283
+ max_attempts: int = 300,
284
+ max_overlap_ratio: float = 0.20,
285
+ center_margin: int = 16
286
+ ) -> List[int]:
287
+ """
288
+ Compute a box (xyxy) that minimizes overlap with existing boxes.
289
+
290
+ Args:
291
+ canvas_size: Size of the canvas (e.g., 512)
292
+ existing_boxes: List of existing boxes to avoid overlapping with
293
+ min_size_ratio: Minimum size as ratio of canvas
294
+ max_size_ratio: Maximum size as ratio of canvas
295
+ max_attempts: Maximum attempts to find a good position
296
+ max_overlap_ratio: Maximum acceptable overlap ratio (default 20%)
297
+ center_margin: Margin from edges for box center (default 16px, so center
298
+ is within 480x480 area for 512 canvas)
299
+
300
+ Strategy:
301
+ 1. Try to find a position with no overlap
302
+ 2. If not possible, accept positions with < max_overlap_ratio overlap
303
+ 3. Return the position with minimum overlap
304
+ """
305
+ best_box = None
306
+ best_overlap_ratio = float('inf')
307
+
308
+ for _ in range(max_attempts):
309
+ box = compute_random_box_xyxy(
310
+ canvas_size, min_size_ratio, max_size_ratio,
311
+ center_margin=center_margin
312
+ )
313
+ box_area = (box[2] - box[0]) * (box[3] - box[1])
314
+
315
+ if box_area <= 0:
316
+ continue
317
+
318
+ overlap = compute_total_overlap(box, existing_boxes)
319
+ overlap_ratio = overlap / box_area
320
+
321
+ # If no overlap, return immediately
322
+ if overlap == 0:
323
+ return box
324
+
325
+ # Track best box
326
+ if overlap_ratio < best_overlap_ratio:
327
+ best_overlap_ratio = overlap_ratio
328
+ best_box = box
329
+
330
+ # Accept if overlap is small enough
331
+ if overlap_ratio < max_overlap_ratio:
332
+ return box
333
+
334
+ # Return the best box found
335
+ if best_box is not None:
336
+ return best_box
337
+
338
+ # Fallback
339
+ return compute_random_box_xyxy(
340
+ canvas_size, min_size_ratio, max_size_ratio,
341
+ center_margin=center_margin
342
+ )
343
+
344
+
345
+ def create_layer_on_canvas(
346
+ layer_img: Image.Image,
347
+ box: List[int],
348
+ canvas_size: int
349
+ ) -> Image.Image:
350
+ """
351
+ Create a full-canvas RGBA image with the layer placed at box position.
352
+ Box is in xyxy format: [x0, y0, x1, y1].
353
+ Layer will have transparent background.
354
+ """
355
+ x0, y0, x1, y1 = box
356
+ w = x1 - x0
357
+ h = y1 - y0
358
+
359
+ # Create transparent canvas
360
+ canvas = Image.new('RGBA', (canvas_size, canvas_size), (0, 0, 0, 0))
361
+
362
+ # Ensure positive dimensions
363
+ if w <= 0 or h <= 0:
364
+ return canvas
365
+
366
+ # Resize layer to fit box
367
+ layer_resized = layer_img.resize((w, h), Image.LANCZOS)
368
+
369
+ # Paste with alpha (preserving transparency)
370
+ if layer_resized.mode == 'RGBA':
371
+ canvas.paste(layer_resized, (x0, y0), layer_resized)
372
+ else:
373
+ layer_resized = layer_resized.convert('RGBA')
374
+ canvas.paste(layer_resized, (x0, y0), layer_resized)
375
+
376
+ return canvas
377
+
378
+
379
+ def get_content_bbox(img: Image.Image) -> Optional[List[int]]:
380
+ """
381
+ Get the tight bounding box of non-transparent content in an RGBA image.
382
+ Returns [x0, y0, x1, y1] or None if the image is fully transparent.
383
+ """
384
+ arr = np.array(img.convert('RGBA'))
385
+ alpha = arr[:, :, 3]
386
+ rows = np.any(alpha > 0, axis=1)
387
+ cols = np.any(alpha > 0, axis=0)
388
+ if not rows.any() or not cols.any():
389
+ return None
390
+ rmin, rmax = np.where(rows)[0][[0, -1]]
391
+ cmin, cmax = np.where(cols)[0][[0, -1]]
392
+ return [int(cmin), int(rmin), int(cmax + 1), int(rmax + 1)]
393
+
394
+
395
+ def get_box_size(box: List[int]) -> Tuple[int, int]:
396
+ """Get width and height from xyxy box."""
397
+ x0, y0, x1, y1 = box
398
+ return (x1 - x0, y1 - y0)
399
+
400
+
401
+ def load_caption_list(caption_jsonl: str) -> List[Dict]:
402
+ """
403
+ Load captions.jsonl as a list (ordered by line number).
404
+ """
405
+ return load_jsonl(caption_jsonl)
406
+
407
+
408
+ def get_laion_caption_from_json(image_path: str) -> str:
409
+ """
410
+ Get LAION image caption from its corresponding .json file.
411
+ """
412
+ json_path = image_path.rsplit('.', 1)[0] + '.json'
413
+
414
+ if os.path.exists(json_path):
415
+ try:
416
+ with open(json_path, 'r', encoding='utf-8') as f:
417
+ data = json.load(f)
418
+ return data.get('caption', '')
419
+ except Exception:
420
+ pass
421
+
422
+ return os.path.basename(image_path).rsplit('.', 1)[0]
423
+
424
+
425
+ def get_laion_images_with_captions(laion_dir: str, laion_jsonl: Optional[str] = None) -> List[Tuple[str, str]]:
426
+ """
427
+ Get all LAION images with their captions.
428
+ """
429
+ images = []
430
+
431
+ for subdir in sorted(os.listdir(laion_dir)):
432
+ subdir_path = os.path.join(laion_dir, subdir)
433
+ if os.path.isdir(subdir_path):
434
+ for fname in sorted(os.listdir(subdir_path)):
435
+ if fname.endswith(('.jpg', '.jpeg', '.png')):
436
+ img_path = os.path.join(subdir_path, fname)
437
+ caption = get_laion_caption_from_json(img_path)
438
+ images.append((img_path, caption))
439
+
440
+ return images
441
+
442
+
443
+ def get_caption_images_with_text(caption_dir: str, caption_list: List[Dict]) -> List[Tuple[str, str]]:
444
+ """
445
+ Get caption images with their text content.
446
+ """
447
+ images = []
448
+
449
+ for fname in sorted(os.listdir(caption_dir)):
450
+ if fname.endswith('.png'):
451
+ img_path = os.path.join(caption_dir, fname)
452
+
453
+ idx_str = fname.split('.')[0]
454
+ try:
455
+ idx = int(idx_str)
456
+ except ValueError:
457
+ idx = -1
458
+
459
+ caption_text = ""
460
+ if 0 <= idx < len(caption_list):
461
+ caption_text = caption_list[idx].get('caption', '')
462
+
463
+ images.append((img_path, caption_text))
464
+
465
+ return images
466
+
467
+
468
+ def extract_layer_from_sample(
469
+ sample_metadata: Dict,
470
+ layer_idx: int
471
+ ) -> Optional[Tuple[Image.Image, Dict]]:
472
+ """
473
+ Extract a specific layer from a sample.
474
+ Returns (layer_image, layer_info) or None if not found.
475
+ """
476
+ layer_images = sample_metadata.get('layer_images', {})
477
+
478
+ if layer_idx not in layer_images:
479
+ return None
480
+
481
+ # Find layer info
482
+ for layer in sample_metadata.get('layers', []):
483
+ if layer['layer_idx'] == layer_idx:
484
+ return (layer_images[layer_idx], layer.copy())
485
+
486
+ return None
487
+
488
+
489
+ def select_random_layers_from_samples(
490
+ sample_dirs: List[str],
491
+ exclude_sample: str,
492
+ num_samples_to_pick: int = 2,
493
+ num_layers_per_sample: Tuple[int, int] = (1, 2)
494
+ ) -> List[Tuple[Image.Image, Dict, str]]:
495
+ """
496
+ Select random layers from random samples.
497
+
498
+ Args:
499
+ sample_dirs: List of all sample directories
500
+ exclude_sample: Sample directory to exclude (the base sample)
501
+ num_samples_to_pick: Number of different samples to pick from (2-3)
502
+ num_layers_per_sample: Range of layers to pick from each sample (min, max)
503
+
504
+ Returns:
505
+ List of (layer_image, layer_info, source_sample) tuples
506
+ """
507
+ # Filter out the base sample
508
+ available_samples = [s for s in sample_dirs if s != exclude_sample]
509
+
510
+ if len(available_samples) < num_samples_to_pick:
511
+ num_samples_to_pick = len(available_samples)
512
+
513
+ # Randomly select samples
514
+ selected_samples = random.sample(available_samples, num_samples_to_pick)
515
+
516
+ collected_layers = []
517
+
518
+ for sample_dir in selected_samples:
519
+ # Load sample
520
+ sample_meta = load_blended_sample(sample_dir)
521
+ if sample_meta is None:
522
+ continue
523
+
524
+ # Get available layers (excluding laion_foreground and caption types to avoid duplicates)
525
+ layers = sample_meta.get('layers', [])
526
+ prism_layers = [l for l in layers if l.get('type') is None] # Original prism layers only
527
+
528
+ if not prism_layers:
529
+ continue
530
+
531
+ # Randomly select how many layers to pick
532
+ min_layers, max_layers = num_layers_per_sample
533
+ num_to_pick = random.randint(min_layers, min(max_layers, len(prism_layers)))
534
+
535
+ # Select random layers
536
+ selected_layers = random.sample(prism_layers, num_to_pick)
537
+
538
+ for layer_info in selected_layers:
539
+ layer_idx = layer_info['layer_idx']
540
+ layer_img = sample_meta.get('layer_images', {}).get(layer_idx)
541
+
542
+ if layer_img is not None:
543
+ collected_layers.append((layer_img, layer_info.copy(), sample_dir))
544
+
545
+ return collected_layers
546
+