File size: 1,819 Bytes
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Creates a meta-data file by combining information derived from the directory structure.
"""
import json
from pathlib import Path

from pydantic import Field

from realfake.config import IMAGE_FORMATS
from realfake.utils import inject_args, Args


class CreateMetadataArgs(Args):
    """Arguments read by ``main`` when building the meta-data file."""

    # Directory that each dataset name is joined onto to locate its files.
    root_dir: Path
    # Single string of dataset names; main() splits it on ",".
    datasets: str = Field(..., help="Comma-separated list of datasets to include in the meta-data file")
    # Output location; main() writes one JSON object per line to this file.
    jsonl_file: Path = Field(..., help="Path to the output JSONL file")


@inject_args
def main(args: CreateMetadataArgs) -> None:
    """Scan the requested dataset directories and write their records as JSONL.

    ``args.datasets`` is split on commas. Whitespace around each name is
    stripped and empty entries (for example from a trailing comma) are
    dropped: an empty name would resolve to ``args.root_dir`` itself and be
    scanned as a flat "fake" dataset.

    A dataset whose name starts with "real" is labelled "real", any other one
    "fake". Names containing "imagenet" are parsed with ``parse_imagenet``,
    all others with ``parse_flat``.

    Every directory is scanned before the output file is opened, so a missing
    dataset does not truncate an already existing ``args.jsonl_file``.

    Raises:
        ValueError: if ``args.datasets`` contains no dataset name at all.
        FileNotFoundError: if a dataset directory does not exist.
    """
    stripped = (name.strip() for name in args.datasets.split(","))
    datasets = [name for name in stripped if name]
    if not datasets:
        raise ValueError(f"no dataset names given: {args.datasets!r}")

    records = []
    for dataset in datasets:
        label = "real" if dataset.startswith("real") else "fake"
        dirpath = args.root_dir / dataset
        if not dirpath.exists():
            # Explicit raise rather than assert: asserts are removed under
            # `python -O`, which would let the scan fail later and less clearly.
            raise FileNotFoundError(f"dataset dir does not exist: {dirpath}")
        parse = parse_imagenet if "imagenet" in dataset else parse_flat
        records.extend(parse(dirpath, label))

    with open(args.jsonl_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def parse_imagenet(dirpath: Path, label: str) -> list:
    """Collect records from a directory holding one sub-directory per class.

    Every entry directly under ``dirpath`` must itself be a directory; its
    name becomes the ``"class"`` of each image found directly inside it.
    Entries whose lower-cased suffix is not in ``IMAGE_FORMATS`` are reported
    on stdout and left out of the result.

    Returns:
        A list of dicts with the keys ``"path"``, ``"label"`` and ``"class"``.

    Raises:
        AssertionError: if an entry of ``dirpath`` is not a directory.
    """
    found = []
    for class_dir in dirpath.iterdir():
        assert class_dir.is_dir(), f"class directory is not a directory: {class_dir}"
        class_name = class_dir.name
        for entry in class_dir.iterdir():
            if entry.suffix.lower() not in IMAGE_FORMATS:
                print("Not an image file:", entry)
                continue
            record = {"path": str(entry), "label": label, "class": class_name}
            found.append(record)
    return found


def parse_flat(dirpath: Path, label: str) -> list:
    """Collect records from a directory that holds its images directly.

    Each entry of ``dirpath`` whose lower-cased suffix is in ``IMAGE_FORMATS``
    yields one record with ``"class"`` set to ``None``; every other entry is
    reported on stdout and left out of the result.

    Returns:
        A list of dicts with the keys ``"path"``, ``"label"`` and ``"class"``.
    """
    found = []
    for entry in dirpath.iterdir():
        is_image = entry.suffix.lower() in IMAGE_FORMATS
        if not is_image:
            print("Not an image file:", entry)
            continue
        found.append({"path": str(entry), "label": label, "class": None})
    return found



# Script entry point. main() is called without arguments here, so the
# @inject_args decorator presumably builds `args` itself (e.g. from the
# command line) — TODO confirm against realfake.utils.
if __name__ == "__main__":
    main()