import datetime
import itertools
import os
import re
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import yaml
def flatten_conf(conf, to=None, parents=None):
"""
Flattens a configuration dict: nested dictionaries are flattened
as key1.key2.key3 = value
conf.yaml:
```yaml
a: 1
b:
c: 2
d:
e: 3
g:
sample: sequential
from: [4, 5]
```
Is flattened to
{
"a": 1,
"b.c": 2,
"b.d.e": 3,
"b.g": {
"sample": "sequential",
"from": [4, 5]
}
}
Does not affect sampling dicts.
    Args:
        conf (dict): the configuration to flatten
        to (dict, optional): the target flattened dict. Defaults to None (a new dict).
        parents (list, optional): a final value's list of parents. Defaults to None.
    Returns:
        dict: the flattened dict `to`
    """
    if to is None:
        to = {}
    if parents is None:
        parents = []
for k, v in conf.items():
if isinstance(v, dict) and "sample" not in v:
flatten_conf(v, to, parents + [k])
        else:
            new_k = ".".join([str(p) for p in parents + [k]])
            to[new_k] = v
    return to
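# A quick sketch of the expected behavior on the docstring's example:
#   flat = flatten_conf(
#       {"a": 1, "b": {"c": 2, "d": {"e": 3}, "g": {"sample": "sequential", "from": [4, 5]}}}
#   )
#   # -> {"a": 1, "b.c": 2, "b.d.e": 3, "b.g": {"sample": "sequential", "from": [4, 5]}}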
def env_to_path(path):
"""Transorms an environment variable mention in a json
into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds
Args:
path (str): path potentially containing the env variable
"""
path_elements = path.split("/")
new_path = []
for el in path_elements:
if "$" in el:
new_path.append(os.environ[el.replace("$", "")])
else:
new_path.append(el)
return "/".join(new_path)
class C:
    """ANSI escape sequences used to color and style terminal output"""
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
ITALIC = "\33[3m"
BEIGE = "\33[36m"
def escape_path(path):
    p = str(path)
    return p.replace(" ", r"\ ").replace("(", r"\(").replace(")", r"\)")
def warn(*args, **kwargs):
print("{}{}{}".format(C.WARNING, " ".join(args), C.ENDC), **kwargs)
def parse_jobID(command_output):
"""
get job id from successful sbatch command output like
`Submitted batch job 599583`
Args:
command_output (str): sbatch command's output
Returns:
int: the slurm job's ID
"""
    if isinstance(command_output, str):
        command_output = command_output.strip()
        if "Submitted batch job" in command_output:
            return int(command_output.split()[-1])
    return -1
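# e.g.:
#   parse_jobID("Submitted batch job 599583\n")      # -> 599583
#   parse_jobID("sbatch: error: invalid partition")  # -> -1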
def now():
return str(datetime.datetime.now()).replace(" ", "_")
def cols():
try:
col = os.get_terminal_size().columns
except Exception:
col = 50
return col
def print_box(txt):
if not txt:
txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
lt = 7
else:
lt = len(txt)
nlt = lt + 12
txt = "|" + " " * 5 + txt + " " * 5 + "|"
line = "-" * nlt
empty = "|" + " " * (nlt - 2) + "|"
print(line)
print(empty)
print(txt)
print(empty)
print(line)
def print_header(idx):
b = C.BOLD
bl = C.OKBLUE
e = C.ENDC
char = "≡"
c = cols()
txt = " " * 20
txt += f"{b}{bl}Run {idx}{e}"
txt += " " * 20
ln = len(txt) - len(b) - len(bl) - len(e)
t = int(np.floor((c - ln) / 2))
tt = int(np.ceil((c - ln) / 2))
print(char * c)
print(char * t + " " * ln + char * tt)
print(char * t + txt + char * tt)
print(char * t + " " * ln + char * tt)
print(char * c)
def print_footer():
c = cols()
char = "﹎"
print()
print(char * (c // len(char)))
print()
print(" " * (c // 2) + "•" + " " * (c - c // 2 - 1))
print()
def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=()):
exclude = set(exclude)
if summary is None:
summary = defaultdict(list)
for k, v in tmp_template_dict.items():
if k not in exclude:
summary[k].append(v)
for k, v in tmp_train_args_dict.items():
if k not in exclude:
if isinstance(v, list):
v = str(v)
summary[k].append(v)
return summary
def search_summary_table(summary, summary_dir=None):
# filter out constant values
summary = {k: v for k, v in summary.items() if len(set(v)) > 1}
# if everything is constant: no summary
if not summary:
return None, None
# find number of searches
n_searches = len(list(summary.values())[0])
# print section title
print(
"{}{}{}Varying values across {} experiments:{}\n".format(
C.OKBLUE,
C.BOLD,
C.UNDERLINE,
n_searches,
C.ENDC,
)
)
# first column holds the Exp. number
first_col = {
"len": 8, # length of a column, to split columns according to terminal width
"str": ["| Exp. |", "|:----:|"]
+ [
"| {0:^{1}} |".format(i, 4) for i in range(n_searches)
], # list of values to print
}
print_columns = [[first_col]]
file_columns = [first_col]
for k in sorted(summary.keys()):
v = summary[k]
col_title = f" {k} |"
col_blank_line = f":{'-' * len(k)}-|"
col_values = [
" {0:{1}} |".format(
crop_string(
                    str(crop_float(v[idx], max(1, min(5, len(k) - 2)))), len(k)
), # crop floats and long strings
len(k),
)
for idx in range(len(v))
]
# create column object
col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}
# if adding a new column would overflow the terminal and mess up printing, start
# new set of columns
if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
print_columns.append([first_col])
# store current column to latest group of columns
print_columns[-1].append(col)
file_columns.append(col)
print_table = ""
# print each column group individually
for colgroup in print_columns:
# print columns line by line
for i in range(n_searches + 2):
# get value of column for current line i
for col in colgroup:
print_table += col["str"][i]
# next line for current columns
print_table += "\n"
# new lines for new column group
print_table += "\n"
file_table = ""
for i in range(n_searches + 2):
# get value of column for current line i
for col in file_columns:
file_table += col["str"][i]
# next line for current columns
file_table += "\n"
summary_path = None
if summary_dir is not None:
summary_path = summary_dir / (now() + ".md")
with summary_path.open("w") as f:
f.write(file_table.strip())
return print_table, summary_path
def clean_arg(v):
"""
chain cleaning function
Args:
v (any): arg to pass to train.py
Returns:
str: parsed value to string
"""
return stringify_list(crop_float(quote_string(resolve_env(v))))
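# The chain, illustrated with hypothetical inputs (assuming HOME=/home/vsch):
#   clean_arg("$HOME/data")  # resolve_env    -> "/home/vsch/data"
#   clean_arg("a b")         # quote_string   -> '"a b"'
#   clean_arg(0.123456789)   # crop_float     -> "0.12346"
#   clean_arg([128, 256])    # stringify_list -> '"[128, 256]"'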
def resolve_env(v):
"""
resolve env variables in paths
Args:
v (any): arg to pass to train.py
Returns:
str: try and resolve an env variable
"""
if isinstance(v, str):
try:
if "$" in v:
if "/" in v:
v = env_to_path(v)
else:
_v = os.environ.get(v)
if _v is not None:
v = _v
except Exception:
pass
return v
def stringify_list(v):
"""
    Stringify list (with double quotes) so that it can be passed as an argument
    to train.py's hydra command-line parsing
Args:
v (any): value to clean
Returns:
any: type of v, str if v was a list
"""
if isinstance(v, list):
return '"{}"'.format(str(v).replace('"', "'"))
if isinstance(v, str):
if v.startswith("[") and v.endswith("]"):
return f'"{v}"'
return v
def quote_string(v):
"""
Add double quotes around string if it contains a " " or an =
Args:
v (any): value to clean
Returns:
any: type of v, quoted if v is a string with " " or =
"""
if isinstance(v, str):
if " " in v or "=" in v:
return f'"{v}"'
return v
def crop_float(v, k=5):
"""
If v is a float, crop precision to 5 digits and return v as a str
Args:
v (any): value to crop if float
Returns:
any: cropped float as str if v is a float, original v otherwise
"""
if isinstance(v, float):
return "{0:.{1}g}".format(v, k)
return v
def compute_n_search(conf):
"""
    Compute the number of searches to run when n_search is -1 and the config
    uses cartesian or sequential sampling
Args:
conf (dict): experimental configuration
Returns:
int: size of the cartesian product or length of longest sequential field
"""
samples = defaultdict(list)
for k, v in conf.items():
if not isinstance(v, dict) or "sample" not in v:
continue
samples[v["sample"]].append(v)
totals = []
if "cartesian" in samples:
total = 1
for s in samples["cartesian"]:
total *= len(s["from"])
totals.append(total)
if "sequential" in samples:
total = max(map(len, [s["from"] for s in samples["sequential"]]))
totals.append(total)
if totals:
return max(totals)
raise ValueError(
"Used n_search=-1 without any field being 'cartesian' or 'sequential'"
)
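# e.g. two cartesian fields with 2 and 3 values yield 2 * 3 = 6 runs
# (key names are illustrative):
#   compute_n_search({
#       "gen.opt.lr": {"sample": "cartesian", "from": [0.01, 0.001]},
#       "data.loaders.batch_size": {"sample": "cartesian", "from": [2, 4, 8]},
#   })  # -> 6
# A "sequential" field of length 10 alongside them would bump this to
# max(6, 10) = 10.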
def crop_string(s, k=10):
if len(s) <= k:
return s
else:
return s[: k - 2] + ".."
def sample_param(sample_dict):
"""sample a value (hyperparameter) from the instruction in the
sample dict:
{
"sample": "range | list",
"from": [min, max, step] | [v0, v1, v2 etc.]
}
if range, as np.arange is used, "from" MUST be a list, but may contain
only 1 (=min) or 2 (min and max) values, not necessarily 3
Args:
sample_dict (dict): instructions to sample a value
Returns:
scalar: sampled value
"""
if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
return sample_dict
if sample_dict["sample"] == "cartesian":
assert isinstance(
sample_dict["from"], list
), "{}'s `from` field MUST be a list, found {}".format(
sample_dict["sample"], sample_dict["from"]
)
return "__cartesian__"
if sample_dict["sample"] == "sequential":
assert isinstance(
sample_dict["from"], list
), "{}'s `from` field MUST be a list, found {}".format(
sample_dict["sample"], sample_dict["from"]
)
return "__sequential__"
if sample_dict["sample"] == "range":
return np.random.choice(np.arange(*sample_dict["from"]))
if sample_dict["sample"] == "list":
return np.random.choice(sample_dict["from"])
if sample_dict["sample"] == "uniform":
return np.random.uniform(*sample_dict["from"])
raise ValueError("Unknown sample type in dict " + str(sample_dict))
def sample_sequentials(sequential_keys, exp, idx):
"""
Samples sequentially from the "from" values specified in each key of the
experimental configuration which have sample == "sequential"
    Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
    over each key
Args:
sequential_keys (list): keys to be sampled sequentially
exp (dict): experimental config
idx (int): index of the current sample
Returns:
conf: sampled dict
"""
conf = {}
for k in sequential_keys:
v = exp[k]["from"]
conf[k] = v[idx % len(v)]
return conf
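# e.g. with exp = {"a": {"sample": "sequential", "from": [1, 2]},
#                  "b": {"sample": "sequential", "from": [10, 20, 30]}}:
#   sample_sequentials(["a", "b"], exp, 2)  # -> {"a": 1, "b": 30}
# each key wraps around independently through idx % len(from)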
def sample_cartesians(cartesian_keys, exp, idx):
"""
Returns the `idx`th item in the cartesian product of all cartesian keys to
be sampled.
Args:
cartesian_keys (list): keys in the experimental configuration that are to
be used in the full cartesian product
exp (dict): experimental configuration
idx (int): index of the current sample
Returns:
dict: sampled point in the cartesian space (with keys = cartesian_keys)
"""
conf = {}
cartesian_values = [exp[key]["from"] for key in cartesian_keys]
product = list(itertools.product(*cartesian_values))
for k, v in zip(cartesian_keys, product[idx % len(product)]):
conf[k] = v
return conf
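# e.g. with exp = {"a": {"sample": "cartesian", "from": [1, 2]},
#                  "b": {"sample": "cartesian", "from": ["x", "y"]}}:
# the product is [(1, "x"), (1, "y"), (2, "x"), (2, "y")], so
#   sample_cartesians(["a", "b"], exp, 2)  # -> {"a": 2, "b": "x"}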
def resolve(hp_conf, nb):
"""
    Samples parameters parametrized in `hp_conf`: should be a dict with
    values which fit `sample_param(dict)`'s API
    Args:
        hp_conf (dict): experiment's parametrization
        nb (int): number of experiments to sample; -1 infers the count from
            the cartesian / sequential fields (see compute_n_search)
    Returns:
        list(dict): sampled configurations
"""
if nb == -1:
nb = compute_n_search(hp_conf)
confs = []
for idx in range(nb):
conf = {}
cartesians = []
sequentials = []
for k, v in hp_conf.items():
candidate = sample_param(v)
if candidate == "__cartesian__":
cartesians.append(k)
elif candidate == "__sequential__":
sequentials.append(k)
else:
conf[k] = candidate
if sequentials:
conf.update(sample_sequentials(sequentials, hp_conf, idx))
if cartesians:
conf.update(sample_cartesians(cartesians, hp_conf, idx))
confs.append(conf)
return confs
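# e.g. a constant, a random field and a cartesian field, over 2 runs
# (key names are illustrative):
#   resolve({
#       "opt": "adam",
#       "lr": {"sample": "list", "from": [0.1, 0.01]},
#       "batch_size": {"sample": "cartesian", "from": [2, 4]},
#   }, 2)
#   # -> [{"opt": "adam", "lr": 0.1 or 0.01, "batch_size": 2},
#   #     {"opt": "adam", "lr": 0.1 or 0.01, "batch_size": 4}]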
def get_template_params(template):
"""
extract args in template str as {arg}
Args:
template (str): sbatch template string
Returns:
list(str): Args required to format the template string
"""
    return [
        s.replace("{", "").replace("}", "")
        for s in re.findall(r"\{.*?\}", template)
    ]
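# e.g. on a hypothetical two-placeholder template:
#   get_template_params("#SBATCH --gres=gpu:{gpus}\ncd {codeloc}")
#   # -> ["gpus", "codeloc"]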
def read_exp_conf(name):
"""
Read hp search configuration from shared/experiment/
specified with or without the .yaml extension
Args:
name (str): name of the template to find in shared/experiment/
Returns:
Tuple(Path, dict): file path and loaded dict
"""
if ".yaml" not in name:
name += ".yaml"
paths = []
dirs = ["shared", "config"]
for d in dirs:
path = Path(__file__).parent / d / "experiment" / name
if path.exists():
paths.append(path.resolve())
if len(paths) == 0:
failed = [Path(__file__).parent / d / "experiment" for d in dirs]
s = "Could not find search config {} in :\n".format(name)
for fd in failed:
s += str(fd) + "\nAvailable:\n"
for ym in fd.glob("*.yaml"):
s += " " + ym.name + "\n"
raise ValueError(s)
if len(paths) == 2:
print(
"Warning: found 2 relevant files for search config:\n{}".format(
"\n".join(paths)
)
)
print("Using {}".format(paths[-1]))
with paths[-1].open("r") as f:
conf = yaml.safe_load(f)
    flat_conf = flatten_conf(conf)
    return (paths[-1], flat_conf)
def read_template(name):
"""
Read template from shared/template/ specified with or without the .sh extension
Args:
name (str): name of the template to find in shared/template/
Returns:
str: file's content as 1 string
"""
if ".sh" not in name:
name += ".sh"
paths = []
dirs = ["shared", "config"]
for d in dirs:
path = Path(__file__).parent / d / "template" / name
if path.exists():
paths.append(path)
if len(paths) == 0:
failed = [Path(__file__).parent / d / "template" for d in dirs]
s = "Could not find template {} in :\n".format(name)
for fd in failed:
s += str(fd) + "\nAvailable:\n"
for ym in fd.glob("*.sh"):
s += " " + ym.name + "\n"
raise ValueError(s)
if len(paths) == 2:
print("Warning: found 2 relevant template files:\n{}".format("\n".join(paths)))
print("Using {}".format(paths[-1]))
with paths[-1].open("r") as f:
return f.read()
def is_sampled(key, conf):
"""
    Is a key sampled or constant? Returns True if conf is empty
    Args:
        key (str): key to check
        conf (dict): hyperparameter search configuration dict
Returns:
bool: key is sampled?
"""
return not conf or (
key in conf and isinstance(conf[key], dict) and "sample" in conf[key]
)
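# e.g. with conf = {"lr": {"sample": "list", "from": [0.1, 0.01]}, "opt": "adam"}:
#   is_sampled("lr", conf)   # -> True
#   is_sampled("opt", conf)  # -> False
#   is_sampled("lr", {})     # -> True (an empty conf counts as sampled)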
if __name__ == "__main__":
"""
Notes:
* Must provide template name as template=name
* `name`.sh should be in shared/template/
"""
# -------------------------------
# ----- Default Variables -----
# -------------------------------
args = sys.argv[1:]
command_output = ""
user = os.environ.get("USER")
home = os.environ.get("HOME")
exp_conf = {}
dev = False
escape = False
verbose = False
template_name = None
hp_exp_name = None
hp_search_nb = None
exp_path = None
resume = None
force_sbatchs = False
sbatch_base = Path(home) / "climategan_sbatchs"
summary_dir = Path(home) / "climategan_exp_summaries"
    hp_search_private = {"n_search", "template", "search", "summary_dir"}
sbatch_path = "hash"
# --------------------------
# ----- Sanity Check -----
# --------------------------
for arg in args:
if "=" not in arg or " = " in arg:
raise ValueError(
"Args should be passed as `key=value`. Received `{}`".format(arg)
)
# --------------------------------
# ----- Parse Command Line -----
# --------------------------------
    args_dict = {arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in args}
assert "template" in args_dict, "Please specify template=xxx"
template = read_template(args_dict["template"])
template_dict = {k: None for k in get_template_params(template)}
train_args = []
for k, v in args_dict.items():
if k == "verbose":
if v != "0":
verbose = True
elif k == "sbatch_path":
sbatch_path = v
elif k == "sbatch_base":
sbatch_base = Path(v).resolve()
elif k == "force_sbatchs":
force_sbatchs = v.lower() == "true"
elif k == "dev":
if v.lower() != "false":
dev = True
elif k == "escape":
if v.lower() != "false":
escape = True
elif k == "template":
template_name = v
elif k == "exp":
hp_exp_name = v
elif k == "n_search":
hp_search_nb = int(v)
elif k == "resume":
resume = f'"{v}"'
template_dict[k] = f'"{v}"'
elif k == "summary_dir":
if v.lower() == "none":
summary_dir = None
else:
summary_dir = Path(v)
elif k in template_dict:
template_dict[k] = v
else:
train_args.append(f"{k}={v}")
# ------------------------------------
# ----- Load Experiment Config -----
# ------------------------------------
if hp_exp_name is not None:
exp_path, exp_conf = read_exp_conf(hp_exp_name)
if "n_search" in exp_conf and hp_search_nb is None:
hp_search_nb = exp_conf["n_search"]
assert (
hp_search_nb is not None
), "n_search should be specified in a yaml file or from the command line"
hps = resolve(exp_conf, hp_search_nb)
else:
hps = [None]
# ---------------------------------
# ----- Run All Experiments -----
# ---------------------------------
if summary_dir is not None:
summary_dir.mkdir(exist_ok=True, parents=True)
summary = None
for hp_idx, hp in enumerate(hps):
# copy shared values
tmp_template_dict = template_dict.copy()
tmp_train_args = train_args.copy()
        tmp_train_args_dict = {
            arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in tmp_train_args
        }
print_header(hp_idx)
# override shared values with run-specific values for run hp_idx/n_search
if hp is not None:
for k, v in hp.items():
if k == "resume" and resume is None:
resume = f'"{v}"'
# hp-search params to ignore
if k in hp_search_private:
continue
if k == "codeloc":
v = escape_path(v)
if k == "output":
Path(v).parent.mkdir(parents=True, exist_ok=True)
# override template params depending on exp config
if k in tmp_template_dict:
if template_dict[k] is None or is_sampled(k, exp_conf):
tmp_template_dict[k] = v
# store sampled / specified params in current tmp_train_args_dict
else:
if k in tmp_train_args_dict:
if is_sampled(k, exp_conf):
# warn if key was specified from the command line
tv = tmp_train_args_dict[k]
warn(
"\nWarning: overriding sampled config-file arg",
"{} to command-line value {}\n".format(k, tv),
)
else:
tmp_train_args_dict[k] = v
# create sbatch file where required
tmp_sbatch_path = None
if sbatch_path == "hash":
tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
tmp_sbatch_name += now() + ".sh"
tmp_sbatch_path = sbatch_base / tmp_sbatch_name
tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
tmp_train_args_dict["exp_file"] = str(exp_path)
else:
tmp_sbatch_path = Path(sbatch_path).resolve()
summary = extend_summary(
summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
)
# format train.py's args and crop floats' precision to 5 digits
tmp_template_dict["train_args"] = " ".join(
sorted(
[
"{}={}".format(k, clean_arg(v))
for k, v in tmp_train_args_dict.items()
]
)
)
if "resume.py" in template and resume is None:
raise ValueError("No `resume` value but using a resume.py template")
# format template with clean dict (replace None with "")
sbatch = template.format(
**{
k: v if v is not None else ""
for k, v in tmp_template_dict.items()
if k in template_dict
}
)
# --------------------------------------
# ----- Execute `sbatch` Command -----
# --------------------------------------
if not dev or force_sbatchs:
if tmp_sbatch_path.exists():
print(f"Warning: overwriting {sbatch_path}")
# write sbatch file
with open(tmp_sbatch_path, "w") as f:
f.write(sbatch)
if not dev:
# escape special characters such as " " from sbatch_path's parent dir
parent = str(tmp_sbatch_path.parent)
if escape:
parent = escape_path(parent)
# create command to execute in a subprocess
command = "sbatch {}".format(tmp_sbatch_path.name)
# execute sbatch command & store output
command_output = subprocess.run(
command.split(), stdout=subprocess.PIPE, cwd=parent
)
command_output = "\n" + command_output.stdout.decode("utf-8") + "\n"
print(f"Running from {parent}:")
print(f"$ {command}")
# ---------------------------------
# ----- Summarize Execution -----
# ---------------------------------
if verbose:
print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
if not dev:
print_box(command_output.strip())
jobID = parse_jobID(command_output.strip())
summary["Slurm JOBID"].append(jobID)
summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")
print(
"{}{}Summary{} {}:".format(
C.UNDERLINE,
C.OKGREEN,
C.ENDC,
f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
)
)
print(
" "
+ "\n ".join(
"{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
)
)
print_footer()
print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n")
table, sum_path = search_summary_table(summary, summary_dir if not dev else None)
if table is not None:
print(table)
print(
"Add `[i]: https://...` at the end of a markdown document",
"to fill in the comet links.\n",
)
if summary_dir is None:
print("Add summary_dir=path to store the printed markdown table ⇪")
else:
print("Saved table in", str(sum_path))
if not dev:
print(
"Cancel entire experiment? \n$ scancel",
" ".join(map(str, summary["Slurm JOBID"])),
)