# sida / cc3m_render.py
# xiangzai's picture
# Add files using upload-large-folder tool
# 7803bdf verified
#!/usr/bin/env python
# coding: utf-8
"""
在 `data_root/` 下已经有 `train/` 和 `validation/` 两个文件夹时:
分别在这两个文件夹内生成对应的 `metadata.jsonl`,不复制任何图片。
`metadata.jsonl` 每行格式:
{"file_name": "subdir/000026831.jpg", "caption": "..."}
其中 `file_name` 是相对当前 split 目录(train/ 或 validation/)的路径。
"""
import argparse
import json
import os
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from pathlib import Path
from typing import Optional, Tuple
from tqdm import tqdm
def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse command-line options for the metadata generation job.

    Args:
        argv: Optional explicit argument list (useful for tests);
            ``None`` means ``sys.argv[1:]`` as usual.

    Returns:
        Namespace with ``data_root``, ``jsonl_name``, ``use_txt_caption``,
        ``num_workers`` and ``max_images``.
    """
    parser = argparse.ArgumentParser(description="Generate per-split metadata.jsonl for imagefolder (no copy)")
    parser.add_argument(
        "--data_root",
        type=str,
        default="/gemini/space/hsd/project/dataset/cc3m-wds",
        help="数据根目录(必须包含 train/ 和 validation/)",
    )
    parser.add_argument(
        "--jsonl_name",
        type=str,
        default="metadata.jsonl",
        help="每个 split 下生成的 jsonl 文件名(默认 metadata.jsonl)",
    )
    # BUG FIX: the original defined only a store_true flag with default=True,
    # so use_txt_caption could never be turned off from the command line.
    # Keep the original flag (and its default) and add an explicit opt-out
    # flag writing to the same dest, which is backward-compatible.
    parser.add_argument(
        "--use_txt_caption",
        dest="use_txt_caption",
        action="store_true",
        default=True,
        help="优先读取同名 .txt 作为 caption(默认开启),否则回落到 .json",
    )
    parser.add_argument(
        "--no_use_txt_caption",
        dest="use_txt_caption",
        action="store_false",
        help="Prefer the sidecar .json caption, falling back to .txt",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="线程数(I/O 密集型建议 8~64 之间按机器调整)",
    )
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="每个 split 最多处理多少张图片(None 表示全部,调试可用)",
    )
    return parser.parse_args(argv)
def read_caption_from_txt(txt_path: Path) -> Optional[str]:
    """Best-effort read of a sidecar ``.txt`` caption file.

    Returns the stripped file contents, or ``None`` when the file is
    missing, empty after stripping, or unreadable for any reason.
    """
    if not txt_path.exists():
        return None
    try:
        text = txt_path.read_text(encoding="utf-8").strip()
    except Exception:
        # Best-effort: any I/O or decoding problem simply means "no caption".
        return None
    return text if text else None
def read_caption_from_json(json_path: Path) -> Optional[str]:
    """Best-effort caption lookup from a sidecar ``.json`` file.

    Checks the keys ``caption``, ``text``, ``description`` in that order and
    returns the first non-empty string value (stripped). Returns ``None``
    when the file is missing, unparseable, not a mapping, or holds no
    usable caption.
    """
    if not json_path.exists():
        return None
    try:
        payload = json.loads(json_path.read_text(encoding="utf-8"))
        if isinstance(payload, dict):
            for field in ("caption", "text", "description"):
                value = payload.get(field)
                if isinstance(value, str) and value.strip():
                    return value.strip()
    except Exception:
        # Best-effort: treat any parse/read failure as "no caption".
        return None
    return None
def main() -> None:
    """Write a ``metadata.jsonl`` into each of ``data_root/train`` and
    ``data_root/validation`` without copying any images.

    Each output line maps an image path (relative to its split directory,
    normalized to forward slashes) to the caption read from its sidecar
    ``.txt``/``.json`` file. Images with no resolvable caption are skipped.

    Raises:
        FileNotFoundError: if ``data_root`` or either split directory is missing.
    """
    args = parse_args()
    data_root = Path(args.data_root).resolve()
    if not data_root.exists():
        raise FileNotFoundError(f"数据根目录不存在:{data_root}")
    splits = [("train", data_root / "train"), ("validation", data_root / "validation")]
    # Fail fast before any work if a split directory is absent.
    for _split_name, split_dir in splits:
        if not split_dir.exists():
            raise FileNotFoundError(f"缺少目录:{split_dir}(需要 train/ 和 validation/)")
    def iter_images(split_dir: Path):
        """Lazily yield every .jpg/.jpeg/.png path under split_dir (recursive)."""
        for root, _dirs, files in os.walk(split_dir):
            for name in files:
                if name.lower().endswith((".jpg", ".jpeg", ".png")):
                    yield Path(root) / name
    def process_one(img_path: Path, split_dir: Path) -> Optional[Tuple[str, str]]:
        """Resolve the caption for one image.

        Returns ``(relative_path, caption)`` or ``None`` when neither the
        sidecar .txt nor .json yields a caption.
        """
        txt_path = img_path.with_suffix(".txt")
        json_path = img_path.with_suffix(".json")
        # Try the preferred caption source first, then fall back to the other.
        # The readers return None or a non-empty string, so `or` is safe here.
        if args.use_txt_caption:
            caption = read_caption_from_txt(txt_path) or read_caption_from_json(json_path)
        else:
            caption = read_caption_from_json(json_path) or read_caption_from_txt(txt_path)
        if caption is None:
            return None
        rel = img_path.relative_to(split_dir)
        # Normalize to forward slashes so the jsonl is portable across OSes.
        return str(rel).replace(os.sep, "/"), caption
    for split_name, split_dir in splits:
        jsonl_path = split_dir / args.jsonl_name
        img_iter = iter_images(split_dir)
        if args.max_images is not None:
            img_iter = islice(img_iter, args.max_images)
        written = 0
        # Stream end-to-end: images are discovered, captioned (threaded I/O)
        # and written without ever materializing a full list in memory.
        # (The original wrapped img_iter in a pointless generator; removed.)
        with jsonl_path.open("w", encoding="utf-8") as f, ThreadPoolExecutor(max_workers=args.num_workers) as ex:
            # executor.map preserves input order; total is unknown, so tqdm
            # only shows the running count. Bind split_dir as a lambda default
            # to avoid the late-binding-closure hazard on the loop variable.
            for result in tqdm(
                ex.map(lambda p, d=split_dir: process_one(p, d), img_iter),
                desc=f"[{split_name}] Processing",
            ):
                if result is None:
                    continue
                file_name, caption = result
                f.write(json.dumps({"file_name": file_name, "caption": caption}, ensure_ascii=False) + "\n")
                written += 1
        print(f"{split_name}: 写入 {written} 条 -> {jsonl_path}")
if __name__ == "__main__":  # run only when executed as a script, not on import
    main()
# Example background invocation for a long-running render:
# nohup python cc3m_render.py > cc3m_render.log 2>&1 &