|
import json |
|
from hashlib import md5 |
|
from pathlib import Path |
|
|
|
import datasets |
|
from tqdm import tqdm |
|
|
|
from realfake.utils import Args, inject_args |
|
|
|
|
|
class DownloadParams(Args):
    """Parameters consumed by ``main`` when downloading a DiffusionDB subset.

    NOTE(review): how ``Args`` populates these fields (CLI flags, environment,
    config file) is defined in ``realfake.utils`` and is not visible here.
    """

    # Root directory; ``main`` writes into the ``<output_dir>/<subset>`` folder.
    output_dir: Path

    # Configuration name passed to ``datasets.load_dataset`` for
    # "poloclub/diffusiondb".
    subset: str = "2m_first_1k"
|
|
|
|
|
@inject_args
def main(params: DownloadParams) -> None:
    """Stream a DiffusionDB subset to disk and write a JSON Lines index of it.

    Each streamed item's image is stored as ``<md5(prompt + seed)>.png`` inside
    ``params.output_dir / params.subset``. Images whose target file already
    exists are not saved again, so an interrupted run can be resumed. One
    record per streamed item is written to ``test.jsonl`` in the same
    directory; that file is rewritten from scratch on every run.

    Args:
        params: Download settings; ``output_dir`` is the root folder and
            ``subset`` is the dataset configuration name to load.
    """
    dataset = datasets.load_dataset(
        "poloclub/diffusiondb", params.subset, split="train", streaming=True
    )

    output_dir = params.output_dir / params.subset
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir / "test.jsonl").open("w", encoding="utf-8") as fp:
        for item in tqdm(dataset, total=None):
            image_id = md5((item["prompt"] + str(item["seed"])).encode()).hexdigest()
            filename = output_dir / f"{image_id}.png"

            if not filename.exists():
                # Save under a temporary name and rename afterwards so that an
                # interrupted save never leaves a truncated file at `filename`,
                # which the exists() check above would then skip forever.
                # The temporary name keeps the ".png" extension so the image
                # is saved exactly as it would be under its final name.
                partial = output_dir / f"{image_id}.partial.png"
                item["image"].save(partial)
                partial.replace(filename)

            # The index is rebuilt on every run, so a record is written even
            # when the image itself was already on disk.
            record = {"path": str(filename), "label": "fake", "class": None, "valid": False}
            fp.write(f"{json.dumps(record)}\n")

    print(f"Saved records to {output_dir}")
|
|
|
|
|
# Script entry point: run the download only when executed directly.
if __name__ == "__main__":
    main()
|
|