import datetime
import itertools
import os
import re
import subprocess
import sys
from collections import defaultdict
from pathlib import Path

import numpy as np
import yaml

def flatten_conf(conf, to=None, parents=None):
    """
    Flattens a configuration dict: nested dictionaries are flattened
    as key1.key2.key3 = value

    conf.yaml:
    ```yaml
    a: 1
    b:
        c: 2
        d:
            e: 3
        g:
            sample: sequential
            from: [4, 5]
    ```

    Is flattened to

    {
        "a": 1,
        "b.c": 2,
        "b.d.e": 3,
        "b.g": {
            "sample": "sequential",
            "from": [4, 5]
        }
    }

    Does not affect sampling dicts.

    Args:
        conf (dict): the configuration to flatten
        to (dict, optional): the target flattened dict. Defaults to None (a new dict).
        parents (list, optional): a final value's list of parents. Defaults to None.
    """
    # avoid the mutable-default pitfall: a shared `to={}` default would
    # accumulate keys across calls
    if to is None:
        to = {}
    if parents is None:
        parents = []
    for k, v in conf.items():
        if isinstance(v, dict) and "sample" not in v:
            flatten_conf(v, to, parents + [k])
        else:
            new_k = ".".join([str(p) for p in parents + [k]])
            to[new_k] = v

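# A quick sketch of flatten_conf in use (values are illustrative only):
#
#   flat = {}
#   flatten_conf({"b": {"c": 2, "g": {"sample": "list", "from": [4, 5]}}}, to=flat)
#   # flat == {"b.c": 2, "b.g": {"sample": "list", "from": [4, 5]}}
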
def env_to_path(path):
    """Transforms an environment variable mention in a path string
    into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds

    Args:
        path (str): path potentially containing the env variable

    Returns:
        str: the path with the env variable resolved
    """
    path_elements = path.split("/")
    new_path = []
    for el in path_elements:
        if "$" in el:
            new_path.append(os.environ[el.replace("$", "")])
        else:
            new_path.append(el)
    return "/".join(new_path)

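# Illustrative sketch: assuming HOME=/home/vsch is set in the environment,
#
#   env_to_path("$HOME/clouds")  # -> "/home/vsch/clouds"
#
# Note that an unset variable raises a KeyError since os.environ is
# indexed directly.
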
class C:
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    ITALIC = "\033[3m"
    BEIGE = "\033[36m"

def escape_path(path):
    p = str(path)
    return p.replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)")

def warn(*args, **kwargs):
    print("{}{}{}".format(C.WARNING, " ".join(args), C.ENDC), **kwargs)

def parse_jobID(command_output):
    """
    Get the job id from a successful sbatch command output like
    `Submitted batch job 599583`

    Args:
        command_output (str): sbatch command's output

    Returns:
        int: the slurm job's ID, or -1 if it could not be parsed
    """
    if isinstance(command_output, str):
        command_output = command_output.strip()
        if "Submitted batch job" in command_output:
            return int(command_output.split()[-1])

    return -1

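# Sketch of the two possible outcomes:
#
#   parse_jobID("Submitted batch job 599583")  # -> 599583
#   parse_jobID("sbatch: error: ...")          # -> -1
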
def now():
    return str(datetime.datetime.now()).replace(" ", "_")

def cols():
    try:
        col = os.get_terminal_size().columns
    except Exception:
        col = 50
    return col

def print_box(txt):
    if not txt:
        txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
        lt = 7
    else:
        lt = len(txt)
    nlt = lt + 12
    txt = "|" + " " * 5 + txt + " " * 5 + "|"
    line = "-" * nlt
    empty = "|" + " " * (nlt - 2) + "|"
    print(line)
    print(empty)
    print(txt)
    print(empty)
    print(line)

def print_header(idx):
    b = C.BOLD
    bl = C.OKBLUE
    e = C.ENDC
    char = "≡"
    c = cols()

    txt = " " * 20
    txt += f"{b}{bl}Run {idx}{e}"
    txt += " " * 20
    ln = len(txt) - len(b) - len(bl) - len(e)
    t = int(np.floor((c - ln) / 2))
    tt = int(np.ceil((c - ln) / 2))

    print(char * c)
    print(char * t + " " * ln + char * tt)
    print(char * t + txt + char * tt)
    print(char * t + " " * ln + char * tt)
    print(char * c)

def print_footer():
    c = cols()
    char = "﹎"
    print()
    print(char * (c // len(char)))
    print()
    print(" " * (c // 2) + "•" + " " * (c - c // 2 - 1))
    print()

def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=None):
    """
    Appends the current experiment's template and train args to the
    summary's lists, creating the summary if need be.

    Args:
        summary (dict or None): current summary, one list of values per key
        tmp_train_args_dict (dict): args passed to train.py
        tmp_template_dict (dict): values used to format the sbatch template
        exclude (list, optional): keys to leave out of the summary

    Returns:
        defaultdict(list): the extended summary
    """
    exclude = set(exclude or [])
    if summary is None:
        summary = defaultdict(list)
    for k, v in tmp_template_dict.items():
        if k not in exclude:
            summary[k].append(v)
    for k, v in tmp_train_args_dict.items():
        if k not in exclude:
            if isinstance(v, list):
                v = str(v)
            summary[k].append(v)
    return summary

def search_summary_table(summary, summary_dir=None):
    """
    Builds a markdown table summarizing the values that vary across the
    experiments in `summary`, and optionally saves it to `summary_dir`.

    Args:
        summary (dict): mapping from arg name to its value in each experiment
        summary_dir (Path, optional): where to save the table. Defaults to None.

    Returns:
        Tuple(str, Path): printable table and the path it was saved to;
            (None, None) if no value varies across experiments
    """
    # keep only the keys whose values vary across experiments
    summary = {k: v for k, v in summary.items() if len(set(v)) > 1}

    if not summary:
        return None, None

    n_searches = len(list(summary.values())[0])

    print(
        "{}{}{}Varying values across {} experiments:{}\n".format(
            C.OKBLUE,
            C.BOLD,
            C.UNDERLINE,
            n_searches,
            C.ENDC,
        )
    )

    first_col = {
        "len": 8,
        "str": ["| Exp. |", "|:----:|"]
        + ["| {0:^{1}} |".format(i, 4) for i in range(n_searches)],
    }

    print_columns = [[first_col]]
    file_columns = [first_col]
    for k in sorted(summary.keys()):
        v = summary[k]
        col_title = f" {k} |"
        col_blank_line = f":{'-' * len(k)}-|"
        col_values = [
            " {0:{1}} |".format(
                crop_string(str(crop_float(v[idx], min([5, len(k) - 2]))), len(k)),
                len(k),
            )
            for idx in range(len(v))
        ]

        col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}

        # start a new column group when this one would overflow the terminal
        if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
            print_columns.append([first_col])

        print_columns[-1].append(col)
        file_columns.append(col)

    print_table = ""
    for colgroup in print_columns:
        for i in range(n_searches + 2):
            for col in colgroup:
                print_table += col["str"][i]
            print_table += "\n"
        print_table += "\n"

    file_table = ""
    for i in range(n_searches + 2):
        for col in file_columns:
            file_table += col["str"][i]
        file_table += "\n"

    summary_path = None
    if summary_dir is not None:
        summary_path = summary_dir / (now() + ".md")
        with summary_path.open("w") as f:
            f.write(file_table.strip())

    return print_table, summary_path

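# Rough shape of the resulting markdown table (values illustrative; columns
# are grouped to fit the terminal width when printed):
#
#   | Exp. | lr    | model |
#   |:----:|:------|:------|
#   |  0   | 0.001 | base  |
#   |  1   | 0.01  | large |
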
def clean_arg(v):
    """
    Chains the cleaning functions: resolves env variables, quotes strings,
    crops floats and stringifies lists

    Args:
        v (any): arg to pass to train.py

    Returns:
        str: parsed value to string
    """
    return stringify_list(crop_float(quote_string(resolve_env(v))))

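# The chain applied on made-up values (env resolution assumes HOME=/home/vsch):
#
#   clean_arg("$HOME/data")  # resolve_env    -> "/home/vsch/data"
#   clean_arg("a=b")         # quote_string   -> '"a=b"'
#   clean_arg(0.0000123456)  # crop_float     -> "1.2346e-05"
#   clean_arg([1, 2])        # stringify_list -> '"[1, 2]"'
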
def resolve_env(v):
    """
    resolve env variables in paths

    Args:
        v (any): arg to pass to train.py

    Returns:
        str: try and resolve an env variable
    """
    if isinstance(v, str):
        try:
            if "$" in v:
                if "/" in v:
                    v = env_to_path(v)
                else:
                    # look the variable up without its leading "$"
                    _v = os.environ.get(v.replace("$", ""))
                    if _v is not None:
                        v = _v
        except Exception:
            pass
    return v

def stringify_list(v):
    """
    Stringify list (with double quotes) so that it can be passed as an argument
    to train.py's hydra command-line parsing

    Args:
        v (any): value to clean

    Returns:
        any: type of v, str if v was a list
    """
    if isinstance(v, list):
        return '"{}"'.format(str(v).replace('"', "'"))
    if isinstance(v, str):
        if v.startswith("[") and v.endswith("]"):
            return f'"{v}"'
    return v

def quote_string(v):
    """
    Add double quotes around string if it contains a " " or an =

    Args:
        v (any): value to clean

    Returns:
        any: type of v, quoted if v is a string with " " or =
    """
    if isinstance(v, str):
        if " " in v or "=" in v:
            return f'"{v}"'
    return v

def crop_float(v, k=5):
    """
    If v is a float, crop its precision to k significant digits
    (5 by default) and return it as a str

    Args:
        v (any): value to crop if float
        k (int, optional): significant digits to keep. Defaults to 5.

    Returns:
        any: cropped float as str if v is a float, original v otherwise
    """
    if isinstance(v, float):
        return "{0:.{1}g}".format(v, k)
    return v

def compute_n_search(conf):
    """
    Compute the number of searches to run when using -1 as n_search with
    cartesian or sequential search

    Args:
        conf (dict): experimental configuration

    Returns:
        int: size of the cartesian product or length of longest sequential field
    """
    samples = defaultdict(list)
    for k, v in conf.items():
        if not isinstance(v, dict) or "sample" not in v:
            continue
        samples[v["sample"]].append(v)

    totals = []

    if "cartesian" in samples:
        total = 1
        for s in samples["cartesian"]:
            total *= len(s["from"])
        totals.append(total)
    if "sequential" in samples:
        total = max(map(len, [s["from"] for s in samples["sequential"]]))
        totals.append(total)

    if totals:
        return max(totals)

    raise ValueError(
        "Used n_search=-1 without any field being 'cartesian' or 'sequential'"
    )

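# Sketch: with two cartesian fields of sizes 2 and 3 and one sequential field
# of size 4, n_search=-1 resolves to max(2 * 3, 4) = 6 experiments:
#
#   conf = {
#       "a": {"sample": "cartesian", "from": [1, 2]},
#       "b": {"sample": "cartesian", "from": [3, 4, 5]},
#       "c": {"sample": "sequential", "from": [6, 7, 8, 9]},
#   }
#   compute_n_search(conf)  # -> 6
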
def crop_string(s, k=10):
    if len(s) <= k:
        return s
    else:
        return s[: k - 2] + ".."

def sample_param(sample_dict):
    """sample a value (hyperparameter) from the instruction in the
    sample dict:
    {
        "sample": "range | list",
        "from": [min, max, step] | [v0, v1, v2 etc.]
    }
    if range, as np.arange is used, "from" MUST be a list, but may contain
    only 1 (=stop) or 2 (start and stop) values, not necessarily 3

    Args:
        sample_dict (dict): instructions to sample a value

    Returns:
        scalar: sampled value
    """
    if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
        return sample_dict

    if sample_dict["sample"] == "cartesian":
        assert isinstance(
            sample_dict["from"], list
        ), "{}'s `from` field MUST be a list, found {}".format(
            sample_dict["sample"], sample_dict["from"]
        )
        return "__cartesian__"

    if sample_dict["sample"] == "sequential":
        assert isinstance(
            sample_dict["from"], list
        ), "{}'s `from` field MUST be a list, found {}".format(
            sample_dict["sample"], sample_dict["from"]
        )
        return "__sequential__"

    if sample_dict["sample"] == "range":
        return np.random.choice(np.arange(*sample_dict["from"]))

    if sample_dict["sample"] == "list":
        return np.random.choice(sample_dict["from"])

    if sample_dict["sample"] == "uniform":
        return np.random.uniform(*sample_dict["from"])

    raise ValueError("Unknown sample type in dict " + str(sample_dict))

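# Sketch of each sampling mode (outputs are random; shown values illustrative):
#
#   sample_param({"sample": "list", "from": [0.1, 0.2]})   # -> 0.1 or 0.2
#   sample_param({"sample": "range", "from": [0, 10, 2]})  # -> one of 0, 2, 4, 6, 8
#   sample_param({"sample": "uniform", "from": [0, 1]})    # -> float in [0, 1)
#   sample_param(42)                                       # -> 42 (not a sampling dict)
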
def sample_sequentials(sequential_keys, exp, idx):
    """
    Samples sequentially from the "from" values specified in each key of the
    experimental configuration which have sample == "sequential"
    Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
    over each key

    Args:
        sequential_keys (list): keys to be sampled sequentially
        exp (dict): experimental config
        idx (int): index of the current sample

    Returns:
        conf: sampled dict
    """
    conf = {}
    for k in sequential_keys:
        v = exp[k]["from"]
        conf[k] = v[idx % len(v)]
    return conf

def sample_cartesians(cartesian_keys, exp, idx):
    """
    Returns the `idx`th item in the cartesian product of all cartesian keys to
    be sampled.

    Args:
        cartesian_keys (list): keys in the experimental configuration that are to
            be used in the full cartesian product
        exp (dict): experimental configuration
        idx (int): index of the current sample

    Returns:
        dict: sampled point in the cartesian space (with keys = cartesian_keys)
    """
    conf = {}
    cartesian_values = [exp[key]["from"] for key in cartesian_keys]
    product = list(itertools.product(*cartesian_values))
    for k, v in zip(cartesian_keys, product[idx % len(product)]):
        conf[k] = v
    return conf

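# Contrast of the two deterministic modes for idx = 0, 1, 2, ... (made-up keys),
# with exp = {"a": {..., "from": [1, 2]}, "b": {..., "from": [3, 4, 5]}}:
#
#   sequential: (a, b) = (1, 3), (2, 4), (1, 5), (2, 3), ...  # keys cycle independently
#   cartesian:  (a, b) = (1, 3), (1, 4), (1, 5), (2, 3), ...  # idx-th product item
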
def resolve(hp_conf, nb):
    """
    Samples parameters parametrized in `hp_conf`: should be a dict with
    values which fit `sample_param(dic)`'s API

    Args:
        hp_conf (dict): experiment's parametrization
        nb (int): number of experiments to sample

    Returns:
        list(dict): sampled configurations
    """
    if nb == -1:
        nb = compute_n_search(hp_conf)

    confs = []
    for idx in range(nb):
        conf = {}
        cartesians = []
        sequentials = []
        for k, v in hp_conf.items():
            candidate = sample_param(v)
            if candidate == "__cartesian__":
                cartesians.append(k)
            elif candidate == "__sequential__":
                sequentials.append(k)
            else:
                conf[k] = candidate
        if sequentials:
            conf.update(sample_sequentials(sequentials, hp_conf, idx))
        if cartesians:
            conf.update(sample_cartesians(cartesians, hp_conf, idx))
        confs.append(conf)
    return confs

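# End-to-end sketch on a tiny search space (n_search=-1 expands to the
# cartesian product, here 2 configurations):
#
#   resolve({"lr": {"sample": "cartesian", "from": [0.1, 0.01]}, "epochs": 10}, -1)
#   # -> [{"epochs": 10, "lr": 0.1}, {"epochs": 10, "lr": 0.01}]
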
def get_template_params(template):
    """
    extract args in template str as {arg}

    Args:
        template (str): sbatch template string

    Returns:
        list(str): args required to format the template string
    """
    return [
        s.replace("{", "").replace("}", "")
        for s in re.findall(r"\{.*?\}", template)
    ]

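# Sketch on a two-line template string:
#
#   get_template_params("#SBATCH --gres={gpus}\npython train.py {train_args}")
#   # -> ["gpus", "train_args"]
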
def read_exp_conf(name):
    """
    Read hp search configuration from shared/experiment/
    specified with or without the .yaml extension

    Args:
        name (str): name of the config to find in shared/experiment/

    Returns:
        Tuple(Path, dict): file path and loaded dict
    """
    if not name.endswith(".yaml"):
        name += ".yaml"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "experiment" / name
        if path.exists():
            paths.append(path.resolve())

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "experiment" for d in dirs]
        s = "Could not find search config {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.yaml"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant files for search config:\n{}".format(
                "\n".join(str(p) for p in paths)
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        conf = yaml.safe_load(f)

    flat_conf = {}
    flatten_conf(conf, to=flat_conf)

    return (paths[-1], flat_conf)

def read_template(name):
    """
    Read template from shared/template/ specified with or without the .sh extension

    Args:
        name (str): name of the template to find in shared/template/

    Returns:
        str: file's content as 1 string
    """
    if not name.endswith(".sh"):
        name += ".sh"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "template" / name
        if path.exists():
            paths.append(path)

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "template" for d in dirs]
        s = "Could not find template {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.sh"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant template files:\n{}".format(
                "\n".join(str(p) for p in paths)
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        return f.read()

def is_sampled(key, conf):
    """
    Is a key sampled or constant? Returns True if conf is empty

    Args:
        key (str): key to check
        conf (dict): hyper parameter search configuration dict

    Returns:
        bool: whether the key is sampled
    """
    return not conf or (
        key in conf and isinstance(conf[key], dict) and "sample" in conf[key]
    )

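# Sketch (conf made up):
#
#   conf = {"lr": {"sample": "uniform", "from": [0, 1]}, "epochs": 10}
#   is_sampled("lr", conf)      # -> True
#   is_sampled("epochs", conf)  # -> False
#   is_sampled("lr", {})        # -> True (empty conf counts as sampled)
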
if __name__ == "__main__":

    """
    Notes:
        * Must provide template name as template=name
        * `name`.sh should be in shared/template/
    """
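    # Hypothetical invocations (template and exp names are placeholders):
    #
    #   python this_script.py template=slurm exp=search dev=true
    #   python this_script.py template=slurm lr=0.1 n_search=4 verbose=1
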
    args = sys.argv[1:]
    command_output = ""
    user = os.environ.get("USER")
    home = os.environ.get("HOME")
    exp_conf = {}
    dev = False
    escape = False
    verbose = False
    template_name = None
    hp_exp_name = None
    hp_search_nb = None
    exp_path = None
    resume = None
    force_sbatchs = False
    sbatch_base = Path(home) / "climategan_sbatchs"
    summary_dir = Path(home) / "climategan_exp_summaries"

    hp_search_private = set(["n_search", "template", "search", "summary_dir"])

    sbatch_path = "hash"

    for arg in args:
        if "=" not in arg or " = " in arg:
            raise ValueError(
                "Args should be passed as `key=value`. Received `{}`".format(arg)
            )

    # split on the first "=" only so that values may themselves contain "="
    args_dict = {arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in args}

    assert "template" in args_dict, "Please specify template=xxx"
    template = read_template(args_dict["template"])
    template_dict = {k: None for k in get_template_params(template)}

    train_args = []
    for k, v in args_dict.items():
        if k == "verbose":
            if v != "0":
                verbose = True
        elif k == "sbatch_path":
            sbatch_path = v
        elif k == "sbatch_base":
            sbatch_base = Path(v).resolve()
        elif k == "force_sbatchs":
            force_sbatchs = v.lower() == "true"
        elif k == "dev":
            if v.lower() != "false":
                dev = True
        elif k == "escape":
            if v.lower() != "false":
                escape = True
        elif k == "template":
            template_name = v
        elif k == "exp":
            hp_exp_name = v
        elif k == "n_search":
            hp_search_nb = int(v)
        elif k == "resume":
            resume = f'"{v}"'
            template_dict[k] = f'"{v}"'
        elif k == "summary_dir":
            if v.lower() == "none":
                summary_dir = None
            else:
                summary_dir = Path(v)
        elif k in template_dict:
            template_dict[k] = v
        else:
            train_args.append(f"{k}={v}")

    if hp_exp_name is not None:
        exp_path, exp_conf = read_exp_conf(hp_exp_name)
        if "n_search" in exp_conf and hp_search_nb is None:
            hp_search_nb = exp_conf["n_search"]

        assert (
            hp_search_nb is not None
        ), "n_search should be specified in a yaml file or from the command line"

        hps = resolve(exp_conf, hp_search_nb)

    else:
        hps = [None]

    if summary_dir is not None:
        summary_dir.mkdir(exist_ok=True, parents=True)
    summary = None

    for hp_idx, hp in enumerate(hps):

        tmp_template_dict = template_dict.copy()
        tmp_train_args = train_args.copy()
        tmp_train_args_dict = {
            arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in tmp_train_args
        }
        print_header(hp_idx)

        if hp is not None:
            for k, v in hp.items():
                if k == "resume" and resume is None:
                    resume = f'"{v}"'

                if k in hp_search_private:
                    continue

                if k == "codeloc":
                    v = escape_path(v)

                if k == "output":
                    Path(v).parent.mkdir(parents=True, exist_ok=True)

                if k in tmp_template_dict:
                    if template_dict[k] is None or is_sampled(k, exp_conf):
                        tmp_template_dict[k] = v
                else:
                    # command-line args win over sampled config-file values
                    if k in tmp_train_args_dict:
                        if is_sampled(k, exp_conf):
                            tv = tmp_train_args_dict[k]
                            warn(
                                "\nWarning: overriding sampled config-file arg",
                                "{} to command-line value {}\n".format(k, tv),
                            )
                    else:
                        tmp_train_args_dict[k] = v

        tmp_sbatch_path = None
        if sbatch_path == "hash":
            tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
            tmp_sbatch_name += now() + ".sh"
            tmp_sbatch_path = sbatch_base / tmp_sbatch_name
            tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
            tmp_train_args_dict["exp_file"] = str(exp_path)
        else:
            tmp_sbatch_path = Path(sbatch_path).resolve()

        summary = extend_summary(
            summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
        )

        tmp_template_dict["train_args"] = " ".join(
            sorted(
                "{}={}".format(k, clean_arg(v))
                for k, v in tmp_train_args_dict.items()
            )
        )

        if "resume.py" in template and resume is None:
            raise ValueError("No `resume` value but using a resume.py template")

        sbatch = template.format(
            **{
                k: v if v is not None else ""
                for k, v in tmp_template_dict.items()
                if k in template_dict
            }
        )

        if not dev or force_sbatchs:
            if tmp_sbatch_path.exists():
                print(f"Warning: overwriting {tmp_sbatch_path}")

            with open(tmp_sbatch_path, "w") as f:
                f.write(sbatch)

        if not dev:
            parent = str(tmp_sbatch_path.parent)
            if escape:
                parent = escape_path(parent)

            command = "sbatch {}".format(tmp_sbatch_path.name)
            command_output = subprocess.run(
                command.split(), stdout=subprocess.PIPE, cwd=parent
            )
            command_output = "\n" + command_output.stdout.decode("utf-8") + "\n"

            print(f"Running from {parent}:")
            print(f"$ {command}")

        if verbose:
            print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
        if not dev:
            print_box(command_output.strip())
            jobID = parse_jobID(command_output.strip())
            summary["Slurm JOBID"].append(jobID)

        summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")

        print(
            "{}{}Summary{} {}:".format(
                C.UNDERLINE,
                C.OKGREEN,
                C.ENDC,
                f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
            )
        )
        print(
            " "
            + "\n ".join(
                "{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
            )
        )
        print_footer()

print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n") |
|
|
|
table, sum_path = search_summary_table(summary, summary_dir if not dev else None) |
|
if table is not None: |
|
print(table) |
|
print( |
|
"Add `[i]: https://...` at the end of a markdown document", |
|
"to fill in the comet links.\n", |
|
) |
|
if summary_dir is None: |
|
print("Add summary_dir=path to store the printed markdown table ⇪") |
|
else: |
|
print("Saved table in", str(sum_path)) |
|
|
|
if not dev: |
|
print( |
|
"Cancel entire experiment? \n$ scancel", |
|
" ".join(map(str, summary["Slurm JOBID"])), |
|
) |
|
|