|
import logging |
|
import os |
|
from pathlib import Path |
|
from time import sleep, time |
|
|
|
import hydra |
|
import yaml |
|
from addict import Dict |
|
from comet_ml import ExistingExperiment, Experiment |
|
from omegaconf import OmegaConf |
|
|
|
from climategan.trainer import Trainer |
|
from climategan.utils import ( |
|
comet_kwargs, |
|
copy_run_files, |
|
env_to_path, |
|
find_existing_training, |
|
flatten_opts, |
|
get_existing_comet_id, |
|
get_git_branch, |
|
get_git_revision_hash, |
|
get_increased_path, |
|
kill_job, |
|
load_opts, |
|
pprint, |
|
) |
|
|
|
logging.basicConfig() |
|
logging.getLogger().setLevel(logging.ERROR) |
|
|
|
hydra_config_path = Path(__file__).resolve().parent / "shared/trainer/config.yaml" |
|
|
|
|
|
|
|
@hydra.main(config_path=hydra_config_path, strict=False)
def main(opts):
    """
    Train (or resume training of) a ClimateGAN model.

    Opts prevalence:
    1. Load file specified in args.default (or shared/trainer/defaults.yaml
        if none is provided)
    2. Update with file specified in args.config (or no update if none is provided)
    3. Update with parsed command-line arguments

    e.g.
    `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
    loads defaults, overrides with values in large-lr.yaml and sets batch_size to 10
    """
    # Hydra hands us an OmegaConf object; convert it to an addict.Dict for
    # attribute-style access, and split off the meta "args" namespace
    # (config/default/resume/no_comet/note/comet_tags) from the model opts.
    hydra_opts = Dict(OmegaConf.to_container(opts))
    args = hydra_opts.pop("args", None)
    auto_resumed = {}  # bookkeeping, logged to comet if auto-resume kicks in

    config_path = args.config

    # Explicit resume: reuse the opts.yaml previously stored in the run's
    # output dir so the run continues with the exact options it started with.
    if hydra_opts.train.resume:
        out_ = str(env_to_path(hydra_opts.output_path))
        config_path = Path(out_) / "opts.yaml"
        if not config_path.exists():
            config_path = None
            print("WARNING: could not reuse the opts in {}".format(out_))

    default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"

    # Merge order (lowest to highest precedence): defaults file, config
    # file, command-line overrides — see docstring.
    opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
    if args.resume:
        opts.train.resume = True

    # Record SLURM context (both are None outside a SLURM job) and expand
    # environment variables (e.g. $SLURM_TMPDIR) in the output path.
    opts.jobID = os.environ.get("SLURM_JOBID")
    opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
    opts.output_path = str(env_to_path(opts.output_path))
    print("Config output_path:", opts.output_path)

    exp = comet_previous_id = None

    # Auto-resume: when not explicitly resuming but auto_resume is enabled,
    # look for a matching previous run and, if found, continue it in place.
    if not opts.train.resume and opts.train.auto_resume:
        print("\n\nTrying to auto-resume...")
        existing_path = find_existing_training(opts)
        if existing_path is not None and existing_path.exists():
            auto_resumed["original output_path"] = str(opts.output_path)
            auto_resumed["existing_path"] = str(existing_path)
            opts.train.resume = True
            opts.output_path = str(existing_path)

    # Fresh run: pick a non-clobbering output dir (presumably a suffix is
    # incremented if the path already exists — see get_increased_path).
    if not opts.train.resume:
        opts.output_path = str(get_increased_path(opts.output_path))
    Path(opts.output_path).mkdir(parents=True, exist_ok=True)

    # Snapshot run files into the output dir for reproducibility.
    copy_run_files(opts)

    # Record the exact code version used for this run.
    opts.git_hash = get_git_revision_hash()
    opts.git_branch = get_git_branch()

    # Comet.ml experiment setup (skipped entirely with args.no_comet=True).
    if not args.no_comet:

        # Resuming: re-attach to the comet experiment whose id was stored
        # alongside the run's outputs, if it can be recovered.
        if opts.train.resume:

            assert Path(opts.output_path).exists(), "Output_path does not exist"

            comet_previous_id = get_existing_comet_id(opts.output_path)

            if comet_previous_id is None:
                print("WARNING could not retreive previous comet id")
                print(f"from {opts.output_path}")
            else:
                print("Continuing previous experiment", comet_previous_id)
                auto_resumed["continuing exp id"] = comet_previous_id
                exp = ExistingExperiment(
                    previous_experiment=comet_previous_id, **comet_kwargs
                )
                print("Comet Experiment resumed")

        # No previous experiment found (or fresh run): start a new one and
        # upload the code folder and this script as assets.
        if exp is None:

            print("Starting new experiment")
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(Path(__file__).parent / "climategan"),
                recursive=True,
                log_file_name=True,
            )
            exp.log_asset(str(Path(__file__)))

        if args.note:
            exp.log_parameter("note", args.note)

        # Merge CLI tags with config tags, always tagging the git branch.
        if args.comet_tags or opts.comet.tags:
            tags = set([f"branch:{opts.git_branch}"])
            if args.comet_tags:
                tags.update(args.comet_tags)
            if opts.comet.tags:
                tags.update(opts.comet.tags)
            opts.comet.tags = list(tags)
            print("Logging to comet.ml with tags", opts.comet.tags)
            exp.add_tags(opts.comet.tags)

        # Log the flattened opts, plus the auto-resume trail if any.
        exp.log_parameters(flatten_opts(opts))
        if auto_resumed:
            exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))

        # NOTE(review): presumably gives comet a moment to register the
        # experiment before its url is read below — confirm.
        sleep(1)

        # Persist the experiment url next to the run outputs (path is
        # incremented so resumed runs keep a history of urls).
        url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
        with open(url_path, "w") as f:
            f.write(exp.url)

    # Persist the final merged opts so the run can be resumed later
    # (consumed by the train.resume branch above on the next invocation).
    opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
    with (opts_path).open("w") as f:
        yaml.safe_dump(opts.to_dict(), f)

    pprint("Running model in", opts.output_path)

    # Build the trainer (exp may be None when comet is disabled), stamp the
    # start time, then set up and run training.
    trainer = Trainer(opts, comet_exp=exp, verbose=1)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    pprint("Done training")
    # Release the SLURM allocation once training completes (jobID is None
    # outside SLURM — TODO confirm kill_job is a no-op in that case).
    kill_job(opts.jobID)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: Hydra parses the config file and CLI overrides into
    # `opts` before invoking main().
    main()
|
|