Doven
update code.
f7009b3
import os
import os.path as op
# Locate the repository root by walking up from this script's *absolute* location
# until a directory named "Recurrent-Parameter-Generation" is found.
# Using op.abspath keeps this working when __file__ is relative, and walking the
# directory chain (instead of splitting on os.sep) preserves the drive letter on Windows.
_repo_name = "Recurrent-Parameter-Generation"
_root_error_message = """
You need to rename the repository folder to "Recurrent-Parameter-Generation" manually.
Because the whole project depends on this name.
The file structure is as follow:
└─Recurrent-Parameter-Generation
  ├─dataset
  │ ├─cifar10_cnnmedium
  │ ├─...(total 21 folders)
  │ ├─__init__.py
  │ ├─config.json
  │ ├─dataset.py
  │ └─register.py
  ├─model
  │ ├─__init__.py
  │ ├─denoiser.py
  │ ├─diffusion.py
  │ └─...(total 8 files)
  ├─quick_start
  │ ├─set_configs.py
  │ └─auto_start.sh
  ├─workspace
  │ ├─main
  │ ├─evaluate
  │ ├─...(total 6 folders)
  │ └─config.json
  ├─README.md
  └─requirements.txt
"""
root = op.dirname(op.abspath(__file__))
while op.basename(root) != _repo_name:
    _parent = op.dirname(root)
    if _parent == root:
        # Reached the filesystem root without finding the repository folder.
        # Raise explicitly (not assert) so the check survives `python -O`.
        raise RuntimeError(_root_error_message)
    root = _parent
# Step 1: ask where small datasets (CIFAR10/CIFAR100, ...) should be downloaded.
# The default is a "Dataset" folder next to the repository root.
print("\n1. Set an \033[91mABSOLUTE\033[0m path to download your small dataset, such as CIFAR10 and CIFAR100")
default_dataset_root = op.join(op.dirname(op.abspath(root)), 'Dataset')
# Strip stray whitespace so an input of only spaces falls back to the default.
dataset_root = input(f"[{default_dataset_root} (default & \033[32mrecommanded\033[0m)]: ").strip() or default_dataset_root
# The stored path must be absolute: expand "~" and resolve relative input against
# the current directory, so the printed value is exactly what gets saved.
dataset_root = op.abspath(op.expanduser(dataset_root))
print(f"\033[32mdataset_root is set to {dataset_root}\033[0m")
# Step 2: optionally ask for an ImageNet1k root containing "train" and "test" splits.
# An empty answer disables ImageNet1k (both split paths are stored as None).
print("\n2. Set the \033[91mABSOLUTE\033[0m path to your ImageNet1k dataset. "
      "\033[32m(Press ENTER if you don't want to use ImageNet1k)\033[0m")
print("""The ImageNet1k dataset should be organized as follow:
└─ImageNet1k
  ├─train
  │ ├─n01443537
  │ ├─n01484850
  │ ├─n########
  └─test
    ├─n01443537
    ├─n01484850
    └─n########""")
imagenet_root = input("[None (default)]: ").strip()
if not imagenet_root:
    print("\033[32mWe don't use ImageNet1k.\033[0m")
    imagenet_root_train = None
    imagenet_root_test = None
else:  # imagenet path is set
    # Normalise to an absolute path ("~" expanded) before it is written to the config.
    imagenet_root = op.abspath(op.expanduser(imagenet_root))
    print(f"\033[32mimagenet_root is set to {imagenet_root}\033[0m")
    imagenet_root_train = op.join(imagenet_root, "train")
    imagenet_root_test = op.join(imagenet_root, "test")
    # Validate with explicit exceptions (asserts are stripped under `python -O`).
    for _split_dir in (imagenet_root_train, imagenet_root_test):
        if not op.isdir(_split_dir):
            raise FileNotFoundError(f"{_split_dir} does not exist.")
# Step 3: ask whether to log with wandb and, if so, for the API key.
print("\n3. Do you want to use wandb?")
default_use_wandb = True
use_wandb = input("[True (default & \033[32mrecommanded\033[0m)) / False]: ").strip()
# Parse the answer explicitly rather than with eval(): eval on raw user input can
# execute arbitrary code, rejects natural answers such as "true"/"no", and lets
# non-bool values (e.g. "1" -> 1) leak into the saved config.
_use_wandb_answer = use_wandb.lower()
if _use_wandb_answer == "":
    use_wandb = default_use_wandb
elif _use_wandb_answer in ("true", "t", "yes", "y", "1"):
    use_wandb = True
elif _use_wandb_answer in ("false", "f", "no", "n", "0"):
    use_wandb = False
else:
    raise ValueError(f"Cannot interpret {use_wandb!r} as True or False.")
print(f"\033[32muse_wandb is set to {use_wandb}\033[0m")
if use_wandb:
    wandb_api_key = input("Set your wandb api key: ").strip()
    if not wandb_api_key:
        raise ValueError("You need to set an API_KEY if you want to use wandb.")
else:
    # Always define the name so later code can read it without a NameError.
    wandb_api_key = None
print()
import json
from pprint import pprint
# dataset/config.json: merge the chosen dataset locations into the existing config.
print()
dataset_config_path = op.join(root, "dataset", "config.json")
with open(dataset_config_path, "r", encoding="utf-8") as f:
    dataset_config = json.load(f)
dataset_config.update({
    "dataset_root": dataset_root,
    "imagenet_root": {
        "train": imagenet_root_train,
        "test": imagenet_root_test,
    },
})
# Write to a sibling temp file and atomically swap it in, so an interrupted run
# cannot leave a truncated config.json behind. indent keeps the file hand-editable.
dataset_config_tmp_path = dataset_config_path + ".tmp"
with open(dataset_config_tmp_path, "w", encoding="utf-8") as f:
    json.dump(dataset_config, f, indent=4)
os.replace(dataset_config_tmp_path, dataset_config_path)
# Report only after the write has succeeded.
print("\033[32mUpdated dataset/config.json as follow:\033[0m")
pprint(dataset_config)
# workspace/config.json: merge the wandb preferences into the existing config.
print()
workspace_config_path = op.join(root, "workspace", "config.json")
with open(workspace_config_path, "r", encoding="utf-8") as f:
    workspace_config = json.load(f)
workspace_config.update({
    "use_wandb": use_wandb,
    # wandb_api_key may not be defined when wandb is disabled; fall back to None.
    # NOTE(review): the key is stored in plain text in this file — confirm that
    # workspace/config.json is excluded from version control.
    "wandb_api_key": globals().get("wandb_api_key", None),
})
# Write to a sibling temp file and atomically swap it in, so an interrupted run
# cannot leave a truncated config.json behind. indent keeps the file hand-editable.
workspace_config_tmp_path = workspace_config_path + ".tmp"
with open(workspace_config_tmp_path, "w", encoding="utf-8") as f:
    json.dump(workspace_config, f, indent=4)
os.replace(workspace_config_tmp_path, workspace_config_path)
# Report only after the write has succeeded.
print("\033[32mUpdated workspace/config.json as follow:\033[0m")
pprint(workspace_config)
print()