|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import toml |
|
import sys |
|
from pathlib import Path |
|
from collections import defaultdict, Counter |
|
|
|
# Extensions treated as captions for the image that shares their stem.
_CAPTION_EXTS = frozenset({".txt", ".tags", ".caption"})
# Extensions treated as training images.
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".jxl"})
# Every extension the scanner looks at; files with any other suffix are ignored.
_TRACKED_EXTS = _CAPTION_EXTS | _IMAGE_EXTS | {".json"}


def _parse_num_repeats(subset_name):
    """Return the integer before the first ``_`` in *subset_name*, or 1 if it is not an int."""
    try:
        return int(subset_name.partition("_")[0])
    except ValueError:
        return 1


def _collect_data_files(subset_path):
    """Group the tracked files directly inside *subset_path* by base stem.

    The base stem is the file name up to its first dot, so ``img.png`` and
    ``img.png.txt`` fall into the same group. The real ``Path`` objects are
    kept (not just their extensions) so later deletion removes the files that
    actually exist instead of a name rebuilt from ``stem + ext``.

    Returns:
        ``defaultdict(list)`` mapping base stem -> list of file paths.
    """
    data_files = defaultdict(list)
    for file in sorted(subset_path.iterdir()):
        # Skip sub-directories and untracked suffixes.
        if not file.is_file() or file.suffix not in _TRACKED_EXTS:
            continue
        stem = file.stem.partition(".")[0]
        if stem == "sample-prompts":
            continue
        data_files[stem].append(file)
    return data_files


def _delete_orphan(paths, *, dry_run):
    """Report the files of one orphaned caption group and delete them unless *dry_run*."""
    for path in paths:
        print(f"Deleting orphan {path}")
    if dry_run:
        return
    for path in paths:
        path.unlink()
    # NOTE(review): the author marked orphan deletion as unfinished. This guard
    # aborts the run after the first real deletion, before config.toml is
    # rewritten. Remove it only once the feature has been reviewed.
    raise NotImplementedError("UNFINISHED DO NOT USE")


def _tally_subset(data_files, subset_stats):
    """Classify each stem group and update the ``Counter`` *subset_stats* in place.

    A group is ``captioned`` (image + caption), ``no_caption`` (image only) or
    an ``orphan`` (caption only). Groups with neither an image nor a caption
    are counted once per distinct extension they contain.
    """
    delete_orphans = "DELETE_ORPHANS" in os.environ
    dry_run = "DEBUG" in os.environ
    for paths in data_files.values():
        exts = {path.suffix for path in paths}
        has_caption = bool(_CAPTION_EXTS & exts)
        has_image = bool(_IMAGE_EXTS & exts)

        if has_caption and has_image:
            subset_stats["captioned"] += 1
        elif has_image:
            subset_stats["no_caption"] += 1
        elif has_caption:
            subset_stats["orphans"] += 1
            if delete_orphans:
                _delete_orphan(paths, dry_run=dry_run)
        else:
            for ext in exts:
                subset_stats[ext] += 1


def update_config(root_dir):
    """Rebuild the subsets list of ``<root_dir>/config.toml`` from the directory tree.

    Every non-hidden directory two levels below *root_dir* (``<dataset>/<subset>``)
    becomes one subset entry. Its ``num_repeats`` comes from an integer prefix
    before the first ``_`` in the subset directory name (default 1). The
    generated list replaces ``config["datasets"][0]["subsets"]``.

    Environment variables:
        DEBUG: print the resulting TOML instead of writing ``config.toml``;
            never delete files.
        DELETE_ORPHANS: report (and, without DEBUG, delete) caption files that
            have no matching image. Unfinished: aborts after the first deletion.

    Args:
        root_dir: directory that contains ``config.toml`` and the dataset folders.

    Returns:
        ``defaultdict`` mapping each subset ``Path`` to a ``Counter`` with keys
        ``captioned``, ``no_caption``, ``orphans``, and per-extension counts
        for groups that are neither image nor caption.

    Raises:
        NotImplementedError: DELETE_ORPHANS is set without DEBUG and an orphan
            was deleted.
    """
    root_dir = Path(root_dir).resolve()
    config_path = root_dir / "config.toml"
    config = toml.load(config_path)

    stats = defaultdict(Counter)
    new_subsets = []

    # Sorted so the generated config is stable across runs; iterdir() order is arbitrary.
    for dataset_path in sorted(root_dir.iterdir()):
        if not dataset_path.is_dir() or dataset_path.name.startswith("."):
            continue
        for subset_path in sorted(dataset_path.iterdir()):
            if not subset_path.is_dir() or subset_path.name.startswith("."):
                continue

            new_subsets.append({
                "image_dir": str(subset_path),
                "num_repeats": _parse_num_repeats(subset_path.name),
            })

            data_files = _collect_data_files(subset_path)
            # Indexing the defaultdict creates an entry, so empty subsets still appear in stats.
            _tally_subset(data_files, stats[subset_path])

    config["datasets"][0]["subsets"] = new_subsets

    if "DEBUG" in os.environ:
        print(toml.dumps(config))
    else:
        # toml.load() decodes config.toml as UTF-8; write it back the same way
        # rather than with the platform's locale encoding.
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config, f)

    return stats
|
|
|
if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Usage errors go to stderr so stdout stays reserved for real output.
        print(
            "Usage: [DEBUG=1] [DELETE_ORPHANS=1] python script.py <ROOT_DIR>",
            file=sys.stderr,
        )
        sys.exit(1)

    root_dir = sys.argv[1]
    stats = update_config(root_dir)

    # One line per subset, ordered by subset path, with its file statistics.
    for subset, subset_stats in sorted(stats.items(), key=lambda item: item[0]):
        print(subset, dict(subset_stats))
|
|