import os
import os.path as op

# Derive the repository root from __file__: keep the path components up to and
# including the folder named "Recurrent-Parameter-Generation".
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])

assert op.exists(root), "Cannot find the executing root."
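# For example, given an illustrative path (not one taken from the repository) such as
#     __file__ == "/home/user/Recurrent-Parameter-Generation/quick_start/set_configs.py"
# the root expression above yields
#     root == "/home/user/Recurrent-Parameter-Generation"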
assert op.basename(root) == "Recurrent-Parameter-Generation", \
    """
You need to rename the repository folder to "Recurrent-Parameter-Generation" manually,
because the whole project depends on this name.
The file structure is as follows:
└─Recurrent-Parameter-Generation
   ├─dataset
   │  ├─cifar10_cnnmedium
   │  ├─...(total 21 folders)
   │  ├─__init__.py
   │  ├─config.json
   │  ├─dataset.py
   │  └─register.py
   ├─model
   │  ├─__init__.py
   │  ├─denoiser.py
   │  ├─diffusion.py
   │  └─...(total 8 files)
   ├─quick_start
   │  ├─set_configs.py
   │  └─auto_start.sh
   ├─workspace
   │  ├─main
   │  ├─evaluate
   │  ├─...(total 6 folders)
   │  └─config.json
   ├─README.md
   └─requirements.txt
"""
print("\n1. Set an \033[91mABSOLUTE\033[0m path to download your small dataset, such as CIFAR10 and CIFAR100") |
|
default_dataset_root = op.join(op.dirname(op.abspath(root)), 'Dataset') |
|
dataset_root = input(f"[{default_dataset_root} (default & \033[32mrecommanded\033[0m)]: ") or default_dataset_root |
|
print(f"\033[32mdataset_root is set to {dataset_root}\033[0m") |
|
|
|
|
|
print("\n2. Set the \033[91mABSOLUTE\033[0m path to your ImageNet1k dataset. " |
|
"\033[32m(Press ENTER if you don't want to use ImageNet1k)\033[0m") |
|
print("""The ImageNet1k dataset should be organized as follow: |
|
└─ImageNet1k |
|
├─train |
|
│ ├─n01443537 |
|
│ ├─n01484850 |
|
│ ├─n######## |
|
└─test |
|
├─n01443537 |
|
├─n01484850 |
|
└─n########""") |
|
imagenet_root = input(f"[None (default)]: ") |
|
if imagenet_root == "": |
|
print("\033[32mWe don't use ImageNet1k.\033[0m") |
|
imagenet_root_train = None |
|
imagenet_root_test = None |
|
else: |
|
print(f"\033[32mimagenet_root is set to {imagenet_root}\033[0m") |
|
imagenet_root_train = op.join(imagenet_root, "train") |
|
imagenet_root_test = op.join(imagenet_root, "test") |
|
assert op.exists(imagenet_root_train), f"{imagenet_root_train} is not existed." |
|
assert op.exists(imagenet_root_test), f"{imagenet_root_test} is not existed." |
|
|
|
|
|
print("\n3. Do you want to use wandb?") |
|
default_use_wandb = True |
|
use_wandb = input("[True (default & \033[32mrecommanded\033[0m)) / False]: ") |
|
use_wandb = default_use_wandb if use_wandb == "" else eval(use_wandb) |
|
print(f"\033[32muse_wandb is set to {use_wandb}\033[0m") |
|
if use_wandb: |
|
wandb_api_key = input("Set your wandb api key: ") |
|
assert wandb_api_key != "", "You need to set an API_KEY is you want to use wandb." |
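# How the stored key is consumed later is up to the training scripts; a minimal,
# hypothetical sketch (the exact call site and project name are assumptions,
# not taken from this repository) would be:
#
#     import wandb
#     wandb.login(key=wandb_api_key)
#     wandb.init(project="recurrent-parameter-generation")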
print()
import json
from pprint import pprint


print()
# Write the collected paths into dataset/config.json, keeping any other keys
# already stored there.
with open(op.join(root, "dataset/config.json"), "r") as f:
    dataset_config = json.load(f)
dataset_config.update({
    "dataset_root": dataset_root,
    "imagenet_root": {
        "train": imagenet_root_train,
        "test": imagenet_root_test,
    },
})
with open(op.join(root, "dataset/config.json"), "w") as f:
    print("\033[32mUpdated dataset/config.json as follows:\033[0m")
    pprint(dataset_config)
    json.dump(dataset_config, f)
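# For reference, dataset/config.json then contains entries roughly like the
# following (the paths are illustrative placeholders, and "..." stands for the
# keys that were already in the file):
#
#     {
#         "dataset_root": "/path/to/Dataset",
#         "imagenet_root": {"train": "/path/to/ImageNet1k/train",
#                           "test": "/path/to/ImageNet1k/test"},
#         ...
#     }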
print()
# Write the wandb settings into workspace/config.json, keeping any other keys
# already stored there.
with open(op.join(root, "workspace/config.json"), "r") as f:
    workspace_config = json.load(f)
workspace_config.update({
    "use_wandb": use_wandb,
    "wandb_api_key": wandb_api_key,
})
with open(op.join(root, "workspace/config.json"), "w") as f:
    print("\033[32mUpdated workspace/config.json as follows:\033[0m")
    pprint(workspace_config)
    json.dump(workspace_config, f)
print()
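# A typical way to run this script, assuming it lives at quick_start/set_configs.py
# as shown in the file tree above (the exact entry point is an assumption):
#
#     cd Recurrent-Parameter-Generation
#     python quick_start/set_configs.py
#
# Both config files are loaded and rewritten in place, so the script can be
# re-run at any time to change these settings.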