SynLayers committed on
Commit 5ed5a04 · verified · 1 Parent(s): 919762b

Upload tools/sample_backgrounds.py with huggingface_hub

Files changed (1)
  1. tools/sample_backgrounds.py +918 -0
tools/sample_backgrounds.py ADDED
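# Samples background images for SynLayers. Three workflows, driven by CLI
# flags: (1) --download on parquet files that carry a URL column, which shells
# out to img2dataset; (2) --download on parquet files with embedded image
# bytes, which extracts them locally; (3) --build-splits, which writes
# train/val/test JSONL manifests for the downloaded images.
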
import argparse
import glob
import json
import logging
import os
import random
import subprocess
from io import BytesIO

import pyarrow as pa
import pyarrow.parquet as pq
from PIL import Image

from tools.dataset import BackgroundDataset, BackgroundIterableDataset


logger = logging.getLogger(__name__)


def iter_samples(dataset, streaming):
    if streaming:
        for sample in dataset:
            yield sample
    else:
        for idx in range(len(dataset)):
            yield dataset[idx]


def parse_args():
    parser = argparse.ArgumentParser(description="Sample background images for SynLayers.")
    parser.add_argument("--dataset-name", default="laion/laion2B-en-aesthetic")
    parser.add_argument(
        "--data-files",
        default="/project/llmsvgen/share/data/kmw_layered_dataset/laion2B-en-aesthetic-image/*.parquet",
        help="Parquet glob or list file.",
    )
    parser.add_argument("--split", default="train")
    parser.add_argument("--cache-dir", default=None)
    parser.add_argument("--url-column", default="URL")
    parser.add_argument("--text-column", default="TEXT")
    parser.add_argument("--hash-column", default="hash")
    parser.add_argument(
        "--image-root",
        default="/project/llmsvgen/share/data/kmw_layered_dataset/laion2B-en-aesthetic-image",
        help="Local directory with downloaded images named by hash.",
    )
    parser.add_argument(
        "--image-extensions",
        default=".jpg,.png,.jpeg,.webp",
        help="Comma-separated extensions to try for local images.",
    )
    parser.add_argument("--image-size", type=int, default=None)
    parser.add_argument("--count", type=int, default=10)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--output-dir", default="./outputs/backgrounds")
    parser.add_argument(
        "--save-images",
        action="store_true",
        help="Save images if found in image-root.",
    )
    parser.add_argument(
        "--download",
        action="store_true",
        help="Download a subset into image-root using img2dataset.",
    )
    parser.add_argument(
        "--download-mode",
        choices=["auto", "img2dataset", "embedded"],
        default="auto",
        help="Download mode: auto-detect URL vs embedded bytes.",
    )
    parser.add_argument("--processes", type=int, default=8)
    parser.add_argument("--threads", type=int, default=32)
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--build-splits", action="store_true")
    parser.add_argument("--train-count", type=int, default=19000)
    parser.add_argument("--val-count", type=int, default=1000)
    parser.add_argument("--test-count", type=int, default=200)
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip downloading/extracting images that already exist in image-root.",
    )
    parser.add_argument(
        "--progress-interval",
        type=int,
        default=500,
        help="Log progress every N extracted images.",
    )
    parser.add_argument(
        "--embedded-image-column",
        default="whole_image",
        help="Struct column containing embedded image bytes.",
    )
    parser.add_argument(
        "--embedded-image-columns",
        default=None,
        help="Comma-separated embedded image columns to try in order.",
    )
    parser.add_argument(
        "--embedded-image-bytes-key",
        default="bytes",
        help="Key inside embedded image struct that stores raw bytes.",
    )
    parser.add_argument(
        "--embedded-image-path-key",
        default="path",
        help="Key inside embedded image struct that stores a path (if any).",
    )
    parser.add_argument(
        "--embedded-caption-column",
        default="whole_caption",
        help="Caption column for embedded images.",
    )
    parser.add_argument(
        "--embedded-id-column",
        default="id",
        help="ID column for embedded images.",
    )
    parser.add_argument(
        "--size-multiple",
        type=int,
        default=8,
        help="Round width/height up to a multiple of this value.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--sequential",
        action="store_true",
        help="Use dataset order instead of random sampling when building splits.",
    )
    parser.add_argument(
        "--allow-partial",
        action="store_true",
        help="Allow writing splits even if there are not enough images.",
    )
    parser.add_argument(
        "--id-as-path",
        action="store_true",
        help="Store image path in the id field instead of the raw key.",
    )
    return parser.parse_args()


def main():
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
    args = parse_args()

    image_extensions = [ext.strip() for ext in args.image_extensions.split(",") if ext.strip()]

    if args.download:
        parquet_files = _expand_parquet_files(args.data_files)
        if not parquet_files:
            raise ValueError("No parquet files found. Check --data-files.")
        os.makedirs(args.image_root, exist_ok=True)
        download_mode = args.download_mode
        if args.embedded_image_columns:
            embedded_image_columns = [
                col.strip() for col in args.embedded_image_columns.split(",") if col.strip()
            ]
        else:
            embedded_image_columns = [args.embedded_image_column]
        if download_mode == "auto":
            if _parquet_has_column(parquet_files, args.url_column):
                download_mode = "img2dataset"
            elif any(
                _parquet_has_column(parquet_files, col) for col in embedded_image_columns
            ):
                download_mode = "embedded"
            else:
                raise ValueError(
                    "Could not detect download mode: missing URL and embedded image columns."
                )
        if download_mode == "img2dataset":
            url_list_path = _prepare_download_parquet(
                parquet_files=parquet_files,
                output_dir=args.output_dir,
                count=args.count,
                seed=args.seed,
                url_column=args.url_column,
                text_column=args.text_column,
                hash_column=args.hash_column,
            )
            cmd = [
                "img2dataset",
                "--url_list",
                url_list_path,
                "--input_format",
                "parquet",
                "--url_col",
                args.url_column,
                "--caption_col",
                args.text_column,
                "--output_format",
                "files",
                "--output_folder",
                args.image_root,
                "--processes_count",
                str(args.processes),
                "--thread_count",
                str(args.threads),
                "--image_size",
                str(args.resize),
                "--resize_mode",
                "keep_ratio",
            ]
            logger.info("Downloading %d images into %s", args.count, args.image_root)
            subprocess.run(cmd, check=True)
        else:
            logger.info(
                "Extracting %d embedded images into %s",
                args.count,
                args.image_root,
            )
            download_embedded_images(
                parquet_files=parquet_files,
                image_root=args.image_root,
                output_dir=args.output_dir,
                count=args.count,
                seed=args.seed,
                sequential=args.sequential,
                id_column=args.embedded_id_column,
                caption_column=args.embedded_caption_column,
                image_columns=embedded_image_columns,
                image_bytes_key=args.embedded_image_bytes_key,
                image_path_key=args.embedded_image_path_key,
                image_extensions=image_extensions,
                skip_existing=args.skip_existing,
                progress_interval=args.progress_interval,
            )

    if args.build_splits:
        if _has_img2dataset_parquet(args.image_root):
            build_splits_from_img2dataset(
                image_root=args.image_root,
                output_dir=args.output_dir,
                train_count=args.train_count,
                val_count=args.val_count,
                test_count=args.test_count,
                seed=args.seed,
                sequential=args.sequential,
                allow_partial=args.allow_partial,
                id_as_path=args.id_as_path,
                image_extensions=image_extensions,
                size_multiple=args.size_multiple,
            )
        else:
            build_splits(
                data_files=args.data_files,
                image_root=args.image_root,
                image_extensions=image_extensions,
                output_dir=args.output_dir,
                train_count=args.train_count,
                val_count=args.val_count,
                test_count=args.test_count,
                seed=args.seed,
                url_column=args.url_column,
                text_column=args.text_column,
                hash_column=args.hash_column,
                sequential=args.sequential,
                allow_partial=args.allow_partial,
                size_multiple=args.size_multiple,
            )
        return

    if args.streaming:
        dataset = BackgroundIterableDataset(
            dataset_name=args.dataset_name,
            data_files=args.data_files,
            split=args.split,
            cache_dir=args.cache_dir,
            url_column=args.url_column,
            text_column=args.text_column,
            hash_column=args.hash_column,
            image_root=args.image_root,
            image_extensions=image_extensions,
            image_size=args.image_size,
            require_image=args.save_images,
        )
    else:
        dataset = BackgroundDataset(
            dataset_name=args.dataset_name,
            data_files=args.data_files,
            split=args.split,
            cache_dir=args.cache_dir,
            url_column=args.url_column,
            text_column=args.text_column,
            hash_column=args.hash_column,
            image_root=args.image_root,
            image_extensions=image_extensions,
            image_size=args.image_size,
            max_items=args.count * 5,
            require_image=args.save_images,
        )

    os.makedirs(args.output_dir, exist_ok=True)
    captions_path = os.path.join(args.output_dir, "captions.jsonl")

    saved = 0
    with open(captions_path, "w", encoding="utf-8") as captions_file:
        for sample in iter_samples(dataset, args.streaming):
            image = sample.get("image")
            filename = None
            if args.save_images:
                if image is None:
                    logger.warning("Skipping sample: local image not found.")
                    continue
                filename = f"background_{saved:03d}.png"
                image.save(os.path.join(args.output_dir, filename))
            captions_file.write(
                json.dumps(
                    {
                        "file": filename,
                        "url": sample.get("url"),
                        "text": sample.get("text"),
                        "width": sample.get("width"),
                        "height": sample.get("height"),
                        "hash": sample.get("hash"),
                        "aesthetic": sample.get("aesthetic"),
                        "punsafe": sample.get("punsafe"),
                        "pwatermark": sample.get("pwatermark"),
                    },
                    ensure_ascii=False,
                )
                + "\n"
            )
            saved += 1
            if saved >= args.count:
                break

    logger.info("Saved %d backgrounds to %s", saved, args.output_dir)


def _expand_parquet_files(data_files):
    if isinstance(data_files, (list, tuple)):
        return list(data_files)
    if not data_files:
        return []
    if os.path.exists(data_files) and data_files.endswith(".parquet"):
        return [data_files]
    return sorted(glob.glob(data_files))


def _parquet_has_column(parquet_files, column_name):
    if not column_name:
        return False
    for parquet_path in parquet_files:
        parquet_file = pq.ParquetFile(parquet_path)
        if column_name in parquet_file.schema.names:
            return True
        schema_arrow = getattr(parquet_file, "schema_arrow", None)
        if schema_arrow is not None and column_name in schema_arrow.names:
            return True
    return False


def _has_img2dataset_parquet(image_root):
    if not image_root or not os.path.exists(image_root):
        return False
    return bool(glob.glob(os.path.join(image_root, "*.parquet")))


def _prepare_download_parquet(
    parquet_files,
    output_dir,
    count,
    seed,
    url_column,
    text_column,
    hash_column,
):
    os.makedirs(output_dir, exist_ok=True)
    if len(parquet_files) == 1:
        return parquet_files[0]
    rng = random.Random(seed)
    columns = [
        url_column,
        text_column,
        hash_column,
        "WIDTH",
        "HEIGHT",
        "aesthetic",
        "punsafe",
        "pwatermark",
    ]
    sampled = _reservoir_sample_parquet(
        parquet_files=parquet_files,
        target_count=count,
        rng=rng,
        columns=columns,
    )
    if not sampled:
        raise ValueError("Failed to sample rows from parquet files.")
    table = pa.Table.from_pylist(sampled)
    out_path = os.path.join(output_dir, "laion_download_sample.parquet")
    pq.write_table(table, out_path)
    logger.info("Wrote sampled parquet list to %s", out_path)
    return out_path


def _detect_image_extension(image):
    fmt = (image.format or "").upper()
    if fmt == "JPEG":
        return "jpg"
    if fmt == "PNG":
        return "png"
    if fmt == "WEBP":
        return "webp"
    return "jpg"


def _collect_existing_images(image_root, image_extensions):
    if not image_root or not os.path.exists(image_root):
        return {}
    image_map = {}
    for root, _, files in os.walk(image_root):
        for name in files:
            ext = os.path.splitext(name)[1].lower()
            if ext in image_extensions:
                stem = os.path.splitext(name)[0]
                image_map[stem] = os.path.join(root, name)
    return image_map


def _save_image_bytes(image_bytes, output_path):
    try:
        with Image.open(BytesIO(image_bytes)) as img:
            ext = _detect_image_extension(img)
            if ext == "jpg":
                # JPEG cannot store an alpha channel, so flatten to RGB.
                img = img.convert("RGB")
            elif img.mode not in ("RGB", "RGBA"):
                img = img.convert("RGBA")
            output_path = os.path.splitext(output_path)[0] + f".{ext}"
            img.save(output_path)
            return output_path, img.size
    except Exception as exc:
        logger.warning("Failed to decode image bytes: %s", exc)
        return None, None


def _iter_embedded_rows(
    parquet_files,
    id_column,
    caption_column,
    image_columns,
    image_bytes_key,
    image_path_key,
):
    columns = [id_column, caption_column] + list(image_columns)
    for parquet_path in parquet_files:
        parquet_file = pq.ParquetFile(parquet_path)
        for batch in parquet_file.iter_batches(columns=columns, batch_size=256):
            batch_dict = batch.to_pydict()
            batch_len = len(batch)
            for i in range(batch_len):
                image_bytes = None
                image_path = None
                for image_column in image_columns:
                    image_struct = batch_dict.get(image_column, [None])[i] or {}
                    image_bytes = image_struct.get(image_bytes_key)
                    image_path = image_struct.get(image_path_key)
                    if image_bytes:
                        break
                if not image_bytes:
                    continue
                yield {
                    "id": batch_dict.get(id_column, [None])[i],
                    "caption": batch_dict.get(caption_column, [None])[i],
                    "bytes": image_bytes,
                    "path": image_path,
                }

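
# Two-pass extraction: when sampling randomly, download_embedded_images first
# reservoir-samples `count` IDs across all parquet rows, then streams the rows
# a second time and decodes only the selected IDs. Memory therefore stays
# proportional to the sample size, not to the dataset.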
def download_embedded_images(
    parquet_files,
    image_root,
    output_dir,
    count,
    seed,
    sequential,
    id_column,
    caption_column,
    image_columns,
    image_bytes_key,
    image_path_key,
    image_extensions,
    skip_existing,
    progress_interval,
):
    os.makedirs(image_root, exist_ok=True)
    rng = random.Random(seed)
    selected_ids = None
    if not sequential:
        sampled = _reservoir_sample_parquet(
            parquet_files=parquet_files,
            target_count=count,
            rng=rng,
            columns=[id_column],
        )
        selected_ids = {
            str(row.get(id_column))
            for row in sampled
            if row.get(id_column) is not None
        }
        if not selected_ids:
            raise ValueError("Failed to sample IDs from parquet files.")

    image_extensions = image_extensions or [".jpg", ".png", ".jpeg", ".webp"]
    existing_map = _collect_existing_images(image_root, image_extensions) if skip_existing else {}
    if existing_map and len(existing_map) >= count:
        logger.info(
            "Found %d existing images in %s (target=%d).",
            len(existing_map),
            image_root,
            count,
        )
    metadata_rows = []
    for row in _iter_embedded_rows(
        parquet_files=parquet_files,
        id_column=id_column,
        caption_column=caption_column,
        image_columns=image_columns,
        image_bytes_key=image_bytes_key,
        image_path_key=image_path_key,
    ):
        image_id = row.get("id")
        if image_id is None:
            continue
        image_id = str(image_id)
        if selected_ids is not None and image_id not in selected_ids:
            continue
        saved_path = None
        size = None
        if image_id in existing_map:
            saved_path = existing_map[image_id]
            size = _get_image_size(saved_path)
        if saved_path is None:
            shard_dir = image_id[:5] if len(image_id) >= 5 else image_id
            target_dir = os.path.join(image_root, shard_dir)
            os.makedirs(target_dir, exist_ok=True)
            target_path = os.path.join(target_dir, image_id)
            saved_path, size = _save_image_bytes(row["bytes"], target_path)
        if not saved_path:
            continue
        width, height = size if size else (None, None)
        metadata_rows.append(
            {
                "key": image_id,
                "caption": row.get("caption"),
                "status": "success",
                "width": width,
                "height": height,
            }
        )
        if progress_interval and len(metadata_rows) % progress_interval == 0:
            logger.info("Extracted %d/%d images...", len(metadata_rows), count)
        if sequential and len(metadata_rows) >= count:
            break
        if selected_ids is not None and len(metadata_rows) >= len(selected_ids):
            break

    if not metadata_rows:
        raise ValueError("No embedded images were extracted.")
    meta_table = pa.Table.from_pylist(metadata_rows)
    meta_path = os.path.join(image_root, "embedded_metadata.parquet")
    pq.write_table(meta_table, meta_path)
    logger.info("Wrote embedded metadata to %s", meta_path)

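
# Reservoir sampling (Algorithm R): fill the reservoir with the first
# `target_count` rows, then give the n-th row a target_count/n chance of
# replacing a uniformly chosen slot. Every row ends up in the sample with
# equal probability, without the total row count being known up front.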
def _reservoir_sample_parquet(parquet_files, target_count, rng, columns):
    sample = []
    total_seen = 0
    for parquet_path in parquet_files:
        parquet_file = pq.ParquetFile(parquet_path)
        for batch in parquet_file.iter_batches(columns=columns, batch_size=4096):
            batch_dict = batch.to_pydict()
            batch_len = len(batch)
            for i in range(batch_len):
                row = {col: batch_dict.get(col, [None])[i] for col in columns}
                total_seen += 1
                if len(sample) < target_count:
                    sample.append(row)
                else:
                    j = rng.randint(0, total_seen - 1)
                    if j < target_count:
                        sample[j] = row
    return sample


def _iter_img2dataset_rows(image_root):
    parquet_files = sorted(glob.glob(os.path.join(image_root, "*.parquet")))
    if not parquet_files:
        return
    columns = ["key", "caption", "status", "width", "height"]
    for parquet_path in parquet_files:
        parquet_file = pq.ParquetFile(parquet_path)
        for batch in parquet_file.iter_batches(columns=columns, batch_size=4096):
            batch_dict = batch.to_pydict()
            batch_len = len(batch)
            for i in range(batch_len):
                status = batch_dict.get("status", [None])[i]
                if status and status != "success":
                    continue
                key = batch_dict.get("key", [None])[i]
                caption = batch_dict.get("caption", [None])[i]
                width = batch_dict.get("width", [None])[i]
                height = batch_dict.get("height", [None])[i]
                if key is None:
                    continue
                key_str = str(key)
                yield {
                    "id": key_str,
                    "caption": caption,
                    "width": width,
                    "height": height,
                }

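
# Assumes an img2dataset-style "files" layout, where each image sits in a
# shard subfolder named by the first five characters of its zero-padded key
# (the same convention download_embedded_images uses when writing). Note the
# final .jpg fallback may point at a file that does not exist.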
def _image_path_from_id(image_root, key_str, image_extensions):
    if not key_str:
        return None
    shard_dir = key_str[:5]
    for ext in image_extensions:
        path = os.path.join(image_root, shard_dir, f"{key_str}{ext}")
        if os.path.exists(path):
            return path
    return os.path.join(image_root, shard_dir, f"{key_str}.jpg")

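
# Worked example: _round_up_multiple(513, 8) == 520, since ceil(513 / 8) = 65
# and 65 * 8 = 520. Recorded sizes are rounded up so that width and height
# divide evenly by --size-multiple (default 8).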
def _round_up_multiple(value, multiple):
    if multiple <= 1:
        return int(value)
    return int(((value + multiple - 1) // multiple) * multiple)


def _get_image_size(path):
    try:
        with Image.open(path) as img:
            return img.size
    except Exception as exc:
        logger.warning("Failed to read image size for %s: %s", path, exc)
        return None

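
# Reads the per-shard parquet metadata that img2dataset writes alongside the
# images, keeps only rows with status == "success", and either takes rows in
# order (--sequential) or reservoir-samples train+val+test rows and shuffles
# them before slicing into the three splits.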
def build_splits_from_img2dataset(
    image_root,
    output_dir,
    train_count,
    val_count,
    test_count,
    seed,
    sequential=False,
    allow_partial=False,
    id_as_path=False,
    image_extensions=None,
    size_multiple=8,
):
    os.makedirs(output_dir, exist_ok=True)
    total_needed = train_count + val_count + test_count
    image_extensions = image_extensions or [".jpg", ".png", ".jpeg", ".webp"]
    items = []
    if sequential:
        for row in _iter_img2dataset_rows(image_root):
            items.append(row)
            if len(items) >= total_needed:
                break
    else:
        rng = random.Random(seed)
        total_seen = 0
        for row in _iter_img2dataset_rows(image_root):
            total_seen += 1
            if len(items) < total_needed:
                items.append(row)
            else:
                j = rng.randint(0, total_seen - 1)
                if j < total_needed:
                    items[j] = row
        rng.shuffle(items)

    if len(items) < total_needed:
        if not allow_partial:
            raise ValueError(
                f"Only found {len(items)} matching images (needed {total_needed})."
            )
        logger.warning(
            "Only found %d matching images (needed %d).",
            len(items),
            total_needed,
        )

    if id_as_path:
        for item in items:
            item["id"] = _image_path_from_id(image_root, item["id"], image_extensions)

    train_items = items[:train_count]
    val_items = items[train_count : train_count + val_count]
    test_items = items[train_count + val_count : train_count + val_count + test_count]

    def write_jsonl(path, rows):
        with open(path, "w", encoding="utf-8") as f:
            for row in rows:
                image_path = row.get("path")
                if not image_path:
                    image_id = row.get("id")
                    if image_id:
                        if os.path.isabs(image_id):
                            image_path = image_id
                        else:
                            image_path = _image_path_from_id(
                                image_root, image_id, image_extensions
                            )
                if image_path:
                    row["path"] = image_path
                size = _get_image_size(image_path)
                if size:
                    width, height = size
                else:
                    width = row.get("width")
                    height = row.get("height")
                if width and height:
                    row["width"] = _round_up_multiple(int(width), size_multiple)
                    row["height"] = _round_up_multiple(int(height), size_multiple)
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    write_jsonl(os.path.join(output_dir, "train.jsonl"), train_items)
    write_jsonl(os.path.join(output_dir, "val.jsonl"), val_items)
    write_jsonl(os.path.join(output_dir, "test.jsonl"), test_items)

    logger.info(
        "Wrote splits to %s (train=%d, val=%d, test=%d)",
        output_dir,
        len(train_items),
        len(val_items),
        len(test_items),
    )


def _scan_images(image_root, image_extensions):
    # Same directory walk as _collect_existing_images; delegate to it rather
    # than duplicating the scan logic.
    return _collect_existing_images(image_root, image_extensions)


def _collect_metadata(
    parquet_files,
    image_map,
    target_count,
    url_column,
    text_column,
    hash_column,
):
    selected = []
    hashes = set(image_map.keys())
    if not hashes:
        return selected
    columns = [
        hash_column,
        url_column,
        text_column,
        "WIDTH",
        "HEIGHT",
        "aesthetic",
        "punsafe",
        "pwatermark",
    ]
    for parquet_path in parquet_files:
        parquet_file = pq.ParquetFile(parquet_path)
        for batch in parquet_file.iter_batches(columns=columns, batch_size=4096):
            batch_dict = batch.to_pydict()
            for i in range(len(batch)):
                hash_value = batch_dict.get(hash_column, [None])[i]
                if hash_value is None:
                    continue
                hash_str = str(hash_value)
                path = image_map.get(hash_str)
                if not path:
                    continue
                selected.append(
                    {
                        "file": path,
                        "url": batch_dict.get(url_column, [None])[i],
                        "text": batch_dict.get(text_column, [None])[i],
                        "width": batch_dict.get("WIDTH", [None])[i],
                        "height": batch_dict.get("HEIGHT", [None])[i],
                        "hash": hash_str,
                        "aesthetic": batch_dict.get("aesthetic", [None])[i],
                        "punsafe": batch_dict.get("punsafe", [None])[i],
                        "pwatermark": batch_dict.get("pwatermark", [None])[i],
                    }
                )
                if len(selected) >= target_count:
                    return selected
    return selected


def build_splits(
    data_files,
    image_root,
    image_extensions,
    output_dir,
    train_count,
    val_count,
    test_count,
    seed,
    url_column,
    text_column,
    hash_column,
    sequential=False,
    allow_partial=False,
    size_multiple=8,
):
    os.makedirs(output_dir, exist_ok=True)
    parquet_files = _expand_parquet_files(data_files)
    if not parquet_files:
        raise ValueError("No parquet files found. Check --data-files.")

    image_map = _scan_images(image_root, image_extensions)
    if not image_map:
        raise ValueError("No images found in image_root.")

    total_needed = train_count + val_count + test_count
    logger.info(
        "Collecting %d samples from %d parquet files (images=%d)",
        total_needed,
        len(parquet_files),
        len(image_map),
    )
    items = _collect_metadata(
        parquet_files=parquet_files,
        image_map=image_map,
        target_count=total_needed,
        url_column=url_column,
        text_column=text_column,
        hash_column=hash_column,
    )
    if len(items) < total_needed:
        if not allow_partial:
            raise ValueError(
                f"Only found {len(items)} matching images (needed {total_needed})."
            )
        logger.warning(
            "Only found %d matching images (needed %d).",
            len(items),
            total_needed,
        )

    if not sequential:
        rng = random.Random(seed)
        rng.shuffle(items)
    train_items = items[:train_count]
    val_items = items[train_count : train_count + val_count]
    test_items = items[train_count + val_count : train_count + val_count + test_count]

    def write_jsonl(path, rows):
        with open(path, "w", encoding="utf-8") as f:
            for row in rows:
                image_path = row.get("path") or row.get("file")
                if image_path:
                    row["path"] = image_path
                size = _get_image_size(image_path)
                if size:
                    width, height = size
                else:
                    width = row.get("width")
                    height = row.get("height")
                if width and height:
                    row["width"] = _round_up_multiple(int(width), size_multiple)
                    row["height"] = _round_up_multiple(int(height), size_multiple)
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    write_jsonl(os.path.join(output_dir, "train.jsonl"), train_items)
    write_jsonl(os.path.join(output_dir, "val.jsonl"), val_items)
    write_jsonl(os.path.join(output_dir, "test.jsonl"), test_items)

    logger.info(
        "Wrote splits to %s (train=%d, val=%d, test=%d)",
        output_dir,
        len(train_items),
        len(val_items),
        len(test_items),
    )


if __name__ == "__main__":
    main()

'''
python -m tools.sample_backgrounds \
    --download \
    --count 20100 \
    --build-splits \
    --train-count 19000 \
    --val-count 1000 \
    --test-count 200 \
    --data-files "/project/llmsvgen/share/data/kmw_layered_dataset/laion2B-en-aesthetic-image/*.parquet" \
    --image-root "/project/llmsvgen/share/data/kmw_layered_dataset/laion2B-en-aesthetic-image" \
    --output-dir "/project/llmsvgen/jinmin/SynLayers/data/laion2b_splits"

python -m tools.sample_backgrounds \
    --download \
    --build-splits \
    --count 40200 \
    --sequential \
    --id-as-path \
    --train-count 19000 \
    --val-count 1000 \
    --test-count 200 \
    --data-files "/project/llmsvgen/share/data/kmw_layered_dataset/PrismLayersPro-image/data/*.parquet" \
    --image-root "/project/llmsvgen/share/data/kmw_layered_dataset/PrismLayersPro-image/data/haolin/PrismLayersPro-image" \
    --output-dir "/project/llmsvgen/jinmin/SynLayers/data/prismlayerspro_splits"
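
# Hypothetical quick smoke test (paths are placeholders, not from the repo):
# write a handful of backgrounds plus captions.jsonl, no downloading.
python -m tools.sample_backgrounds \
    --count 10 \
    --save-images \
    --data-files "/path/to/metadata/*.parquet" \
    --image-root "/path/to/images" \
    --output-dir "./outputs/backgrounds_smoke"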
'''