"""Download a DiffusionDB subset to disk and index it as "fake" samples.

Each image is written to ``<output_dir>/<subset>/<md5>.png`` and described by
one JSON line in ``<output_dir>/<subset>/test.jsonl``.
"""
import json
from hashlib import md5
from pathlib import Path

import datasets
from tqdm import tqdm

from realfake.utils import Args, inject_args


class DownloadParams(Args):
    """Parameters for the DiffusionDB download script."""

    # Root directory; images and the JSONL index go into ``output_dir / subset``.
    output_dir: Path
    # DiffusionDB configuration name passed to ``datasets.load_dataset``.
    subset: str = "2m_first_1k"


def _save_atomically(image, filename: Path) -> None:
    """Write ``image`` as a PNG at ``filename`` via a temp file plus rename.

    ``main`` treats any file already present at ``filename`` as complete and
    skips it. Saving directly to the final path would leave a truncated PNG
    behind if the run is interrupted, and that file would never be rewritten.
    Only a fully written temp file is moved into place.
    """
    # NOTE(review): ``image`` is presumably a PIL.Image produced by the
    # dataset's image feature (the original code called ``.save`` on it) — confirm.
    tmp = filename.with_name(filename.name + ".part")
    # The ".part" suffix hides the ".png" extension, so name the format explicitly.
    image.save(tmp, format="PNG")
    # Same directory, so this is a rename rather than a cross-device copy.
    tmp.replace(filename)


@inject_args
def main(params: DownloadParams) -> None:
    """Stream the configured DiffusionDB subset and save every image as PNG.

    Images already on disk are not downloaded again, so an interrupted run can
    be resumed. The JSONL index is rewritten from scratch on every run, with
    one record per distinct image ID.
    """
    dataset = datasets.load_dataset("poloclub/diffusiondb", params.subset, split="train", streaming=True)
    output_dir = params.output_dir / params.subset
    output_dir.mkdir(parents=True, exist_ok=True)

    seen_ids = set()
    duplicates = 0
    with (output_dir / "test.jsonl").open("w", encoding="utf-8") as fp:
        for item in tqdm(dataset):
            # Unchanged ID scheme, so files from earlier runs are still found.
            image_id = md5((item["prompt"] + str(item["seed"])).encode()).hexdigest()
            if image_id in seen_ids:
                # Another item already mapped to this ID. Its PNG is the only one
                # kept on disk, so emitting a second record would duplicate it.
                duplicates += 1
                continue
            seen_ids.add(image_id)

            filename = output_dir / f"{image_id}.png"
            if not filename.exists():
                _save_atomically(item["image"], filename)

            record = {"path": str(filename), "label": "fake", "class": None, "valid": False}
            fp.write(f"{json.dumps(record)}\n")

    if duplicates:
        print(f"Skipped {duplicates} items whose image id was already used")
    print(f"Saved records to {output_dir}")


if __name__ == "__main__":
    main()