# realfake/realfake/bin/diffusion_db.py
# devforfu
# Init
# ea847ad
import json
from hashlib import md5
from pathlib import Path
import datasets
from tqdm import tqdm
from realfake.utils import Args, inject_args
class DownloadParams(Args):
    """Parameters consumed by ``main`` in this download script."""

    # Base directory; ``main`` writes everything under ``output_dir / subset``.
    output_dir: Path
    # Passed as the second argument to ``datasets.load_dataset`` (presumably
    # the DiffusionDB configuration name) and reused as the sub-directory name.
    subset: str = "2m_first_1k"
def _save_image_atomically(image, destination: Path) -> None:
    """Write ``image`` to ``destination`` as PNG without exposing partial files.

    The image is first written to a sibling ``<name>.part`` file and then
    renamed onto ``destination``. An interrupted run therefore never leaves a
    truncated PNG at the final path, which a later run would otherwise skip
    because the file "already exists".

    Args:
        image: Object exposing a PIL-style ``save(path, format=...)`` method.
        destination: Final path of the PNG file.
    """
    partial = destination.with_name(f"{destination.name}.part")
    try:
        # The ".part" suffix hides the ".png" extension from the encoder, so
        # the output format is named explicitly.
        image.save(partial, format="PNG")
        partial.replace(destination)
    except BaseException:
        # BaseException so that Ctrl-C during a save also removes the leftover.
        if partial.exists():
            partial.unlink()
        raise


@inject_args
def main(params: DownloadParams) -> None:
    """Download a ``poloclub/diffusiondb`` subset and index it as JSON Lines.

    Streams the ``train`` split for ``params.subset``, stores every image as
    ``<md5 hex digest>.png`` under ``params.output_dir / params.subset`` and
    writes one record per streamed item to ``test.jsonl`` in that directory.
    Images that are already on disk are not saved again, so the script can be
    re-run to resume; the index file is rewritten from scratch on every run.

    Args:
        params: Output location and subset name.
    """
    dataset = datasets.load_dataset(
        "poloclub/diffusiondb", params.subset, split="train", streaming=True
    )
    output_dir = params.output_dir / params.subset
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir / "test.jsonl").open("w", encoding="utf-8") as fp:
        for item in tqdm(dataset, total=None):
            # The file name depends on prompt + seed only.
            # NOTE(review): items sharing both would map to the same file and
            # produce repeated records -- confirm whether that can happen.
            image_id = md5((item["prompt"] + str(item["seed"])).encode()).hexdigest()
            filename = output_dir / f"{image_id}.png"
            if not filename.exists():
                _save_image_atomically(item["image"], filename)
            record = {"path": str(filename), "label": "fake", "class": None, "valid": False}
            fp.write(f"{json.dumps(record)}\n")

    print(f"Saved records to {output_dir}")
# Script entry point. ``main`` is called without arguments; the ``inject_args``
# decorator presumably builds ``DownloadParams`` (e.g. from the command line).
if __name__ == "__main__":
    main()