File size: 3,981 Bytes
f7009b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import os.path as op
PROJECT_NAME = "Recurrent-Parameter-Generation"


def _find_project_root(start, name=PROJECT_NAME):
    """Return the nearest ancestor directory of *start* whose basename is *name*.

    *start* is made absolute first, so a relative ``__file__`` works too.
    Returns None when no ancestor carries that name.
    """
    current = op.dirname(op.abspath(start))
    while True:
        if op.basename(current) == name:
            return current
        parent = op.dirname(current)
        if parent == current:  # reached the filesystem root without a match
            return None
        current = parent


# Fall back to the parent of this script's folder (quick_start/..) when the
# repository folder was renamed. The assertion below then reports the
# "rename the folder" guidance instead of an opaque ValueError.
root = _find_project_root(__file__) or op.dirname(op.dirname(op.abspath(__file__)))




assert op.exists(root), "Cannot find the executing root."
assert op.basename(root) == PROJECT_NAME, \
    """
    You need to rename the repository folder to "Recurrent-Parameter-Generation" manually.
    Because the whole project depends on this name.
    The file structure is as follows:
    └─Recurrent-Parameter-Generation
        ├─dataset
        │   ├─cifar10_cnnmedium
        │   ├─...(total 21 folders)
        │   ├─__init__.py
        │   ├─config.json
        │   ├─dataset.py
        │   └─register.py
        ├─model
        │   ├─__init__.py
        │   ├─denoiser.py
        │   ├─diffusion.py
        │   └─...(total 8 files)
        ├─quick_start
        │   ├─set_configs.py
        │   └─auto_start.sh
        ├─workspace
        │   ├─main
        │   ├─evaluate
        │   ├─...(total 6 folders)
        │   └─config.json
        ├─README.md
        └─requirements.txt
    """


# Step 1: where small datasets (e.g. CIFAR10 / CIFAR100) will be downloaded.
# The default is a "Dataset" folder placed next to the repository root.
print("\n1. Set an \033[91mABSOLUTE\033[0m path to download your small dataset, such as CIFAR10 and CIFAR100")
default_dataset_root = op.join(op.dirname(op.abspath(root)), 'Dataset')
# Strip copy/paste whitespace, expand "~", and make the answer absolute.
# A relative answer is resolved against the current working directory now,
# rather than being written verbatim into the config.
dataset_root = input(f"[{default_dataset_root} (default & \033[32mrecommended\033[0m)]: ").strip()
dataset_root = op.abspath(op.expanduser(dataset_root)) if dataset_root else default_dataset_root
print(f"\033[32mdataset_root is set to {dataset_root}\033[0m")


# Step 2: optional ImageNet1k location. An empty answer disables ImageNet1k;
# otherwise the folder must contain "train" and "test" sub-directories.
print("\n2. Set the \033[91mABSOLUTE\033[0m path to your ImageNet1k dataset. "
      "\033[32m(Press ENTER if you don't want to use ImageNet1k)\033[0m")
print("""The ImageNet1k dataset should be organized as follows:
└─ImageNet1k
   ├─train
   │   ├─n01443537
   │   ├─n01484850
   │   ├─n########
   └─test
       ├─n01443537
       ├─n01484850
       └─n########""")
# Strip so that a whitespace-only answer counts as "no ImageNet".
imagenet_root = input("[None (default)]: ").strip()
if not imagenet_root:
    print("\033[32mWe don't use ImageNet1k.\033[0m")
    imagenet_root_train = None
    imagenet_root_test = None
else:  # imagenet path is set
    imagenet_root = op.abspath(op.expanduser(imagenet_root))
    print(f"\033[32mimagenet_root is set to {imagenet_root}\033[0m")
    imagenet_root_train = op.join(imagenet_root, "train")
    imagenet_root_test = op.join(imagenet_root, "test")
    # isdir (not exists): a regular file named train/test is not a valid split.
    assert op.isdir(imagenet_root_train), f"{imagenet_root_train} does not exist or is not a directory."
    assert op.isdir(imagenet_root_test), f"{imagenet_root_test} does not exist or is not a directory."


# Step 3: whether to log to wandb. The answer is parsed from a fixed table
# instead of eval(), which would execute arbitrary user input and crash on
# common answers such as "true" or "yes".
print("\n3. Do you want to use wandb?")
default_use_wandb = True
_BOOL_ANSWERS = {
    "true": True, "t": True, "yes": True, "y": True, "1": True,
    "false": False, "f": False, "no": False, "n": False, "0": False,
}
while True:
    answer = input("[True (default & \033[32mrecommended\033[0m) / False]: ").strip().lower()
    if answer == "":
        use_wandb = default_use_wandb
        break
    if answer in _BOOL_ANSWERS:
        use_wandb = _BOOL_ANSWERS[answer]
        break
    print("\033[91mPlease answer True or False.\033[0m")
print(f"\033[32muse_wandb is set to {use_wandb}\033[0m")
if use_wandb:
    wandb_api_key = input("Set your wandb api key: ").strip()
    assert wandb_api_key != "", "You need to set an API_KEY if you want to use wandb."
else:
    wandb_api_key = None




print()
import json
from pprint import pprint


def _update_json_config(relative_path, updates):
    """Merge *updates* into the JSON file at ``root/relative_path`` and rewrite it.

    Prints the merged config before writing it back, then returns it.
    The new content is serialized before the file is opened for writing.
    A serialization error therefore leaves the existing file intact instead
    of truncated.
    """
    path = op.join(root, relative_path)
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    config.update(updates)
    text = json.dumps(config, indent=4)  # indented so the file stays hand-editable
    print(f"\033[32mUpdated {relative_path} as follows:\033[0m")
    pprint(config)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return config


# dataset/config.json
print()
dataset_config = _update_json_config("dataset/config.json", {
    "dataset_root": dataset_root,
    "imagenet_root": {
        "train": imagenet_root_train,
        "test": imagenet_root_test,
    },
})

# workspace/config.json
print()
workspace_config = _update_json_config("workspace/config.json", {
    "use_wandb": use_wandb,
    # wandb_api_key may be undefined when wandb is disabled; default to None.
    "wandb_api_key": globals().get("wandb_api_key", None),
})
print()