# Source: realfake/realfake/bin/unpack_diffusion_db.py
# Last commit: devforfu, "Init" (ea847ad)
import json
import zipfile
from itertools import chain
from pathlib import Path
from joblib import Parallel, delayed
from realfake.utils import get_user_name, inject_args, Args
class UnpackParams(Args):
    """Parameters consumed by ``main`` when unpacking archives.

    NOTE(review): how ``Args`` populates these fields (CLI flags, env, ...)
    is defined in ``realfake.utils`` and is not visible here -- confirm there.
    """

    # JSON file listing the archives to unpack; its stem names the subset.
    meta_file: Path
    # Destination JSON-lines file receiving one record per extracted image.
    jsonl_file: Path
    # Passed to joblib as ``n_jobs`` for the parallel extraction.
    num_workers: int = 16
def unpack(zip_path: Path, output_dir: Path):
    """Extract every member of a zip archive and report the PNG files written.

    Args:
        zip_path: Archive to read.
        output_dir: Directory the archive is extracted into.

    Returns:
        A list of string paths, one per archive member whose name ends with
        ``.png``, in archive order. Each path is the location the file was
        actually written to on disk.

    Raises:
        zipfile.BadZipFile: If ``zip_path`` is not a valid zip archive.
        OSError: If the archive cannot be read or a member cannot be written.
    """
    print("extracting", zip_path)
    png_paths = []
    with zipfile.ZipFile(zip_path, "r") as arch:
        for member in arch.namelist():
            # zipfile sanitises member names while extracting (leading
            # slashes and ".." components are dropped), so joining the raw
            # name onto output_dir can name a file that was never created.
            # extract() returns the real destination; record that instead.
            written = arch.extract(member, output_dir)
            if member.endswith(".png"):
                png_paths.append(written)
    return png_paths
@inject_args
def main(params: UnpackParams) -> None:
    """Unpack the archives listed in the meta file and index the images.

    The meta file is read as JSON and iterated as a sequence of entries,
    each providing a ``"path"`` (archive location) and an ``"ok"`` flag;
    entries with a falsy ``"ok"`` are skipped. Archives are extracted in
    parallel into ``/fsx/<user>/data/fake_<meta file stem>``, and one JSON
    line per extracted image path is written to ``params.jsonl_file``.
    """
    subset = params.meta_file.stem
    target_dir = Path(f"/fsx/{get_user_name()}/data/fake_{subset}")
    target_dir.mkdir(parents=True, exist_ok=True)

    entries = json.loads(params.meta_file.read_text())

    # One deferred unpack() call per usable archive; consumed by the pool.
    jobs = (
        delayed(unpack)(Path(entry["path"]), target_dir)
        for entry in entries
        if entry["ok"]
    )
    with Parallel(n_jobs=params.num_workers, verbose=100) as pool:
        per_archive = pool(jobs)

    # Flatten the per-archive path lists into serialized JSON-lines rows.
    rows = [
        json.dumps(
            {"path": str(image), "label": "fake", "class": None, "valid": None}
        )
        + "\n"
        for batch in per_archive
        for image in batch
    ]
    with params.jsonl_file.open("w") as fp:
        fp.writelines(rows)
# Run the unpacking only when executed as a script, not on import.
# main() is called without arguments; the inject_args decorator is
# presumably what supplies ``params`` -- confirm in realfake.utils.
if __name__ == "__main__":
    main()