# climateGAN / train.py
# Author: vict0rsch
# Initial commit from `vict0rsch/climateGAN` (revision ce190ee)
import logging
import os
from pathlib import Path
from time import sleep, time
import hydra
import yaml
from addict import Dict
from comet_ml import ExistingExperiment, Experiment
from omegaconf import OmegaConf
from climategan.trainer import Trainer
from climategan.utils import (
comet_kwargs,
copy_run_files,
env_to_path,
find_existing_training,
flatten_opts,
get_existing_comet_id,
get_git_branch,
get_git_revision_hash,
get_increased_path,
kill_job,
load_opts,
pprint,
)
# Only ERROR and above reach the root logger. basicConfig() installs a default
# handler when the root logger has none; the explicit setLevel below applies
# even if handlers were already present.
logging.basicConfig()
logging.root.setLevel(logging.ERROR)

# Hydra's entry config, located relative to this script (not the CWD).
hydra_config_path = Path(__file__).resolve().parent.joinpath(
    "shared", "trainer", "config.yaml"
)
def _resolve_config_path(args, hydra_opts):
    """Return the YAML file whose values override the defaults.

    By default this is ``args.config``. When ``train.resume`` is set in the
    Hydra options, the ``opts.yaml`` stored in the run's output directory is
    reused instead. If that file does not exist, a warning is printed and no
    override file is used (``None``).

    Args:
        args: the ``args`` group popped from the Hydra config (addict Dict).
        hydra_opts: the remaining Hydra / command-line options (addict Dict).

    Returns:
        ``args.config`` unchanged, a ``Path`` to a previous ``opts.yaml``,
        or ``None``.
    """
    config_path = args.config
    if hydra_opts.train.resume:
        out_ = str(env_to_path(hydra_opts.output_path))
        config_path = Path(out_) / "opts.yaml"
        if not config_path.exists():
            config_path = None
            print(f"WARNING: could not reuse the opts in {out_}")
    return config_path


def _resolve_output_path(opts, auto_resumed):
    """Pick the directory this run writes to, updating ``opts`` in place.

    If the run is not already resuming and ``train.auto_resume`` is set,
    ``find_existing_training`` is asked for a previous run (the intent is to
    continue a requeued SLURM job). If one exists on disk, the run switches to
    resume mode there and the switch is recorded in ``auto_resumed``.
    Otherwise the output path is replaced by ``get_increased_path(...)``
    (presumably a not-yet-used variant of it) and that directory is created.

    Args:
        opts: the loaded options (addict Dict); ``train.resume`` and
            ``output_path`` may be modified.
        auto_resumed: dict collecting human-readable resume details; mutated.
    """
    # Auto-continue if same slurm job ID (=job was requeued)
    if not opts.train.resume and opts.train.auto_resume:
        print("\n\nTrying to auto-resume...")
        existing_path = find_existing_training(opts)
        if existing_path is not None and existing_path.exists():
            auto_resumed["original output_path"] = str(opts.output_path)
            auto_resumed["existing_path"] = str(existing_path)
            opts.train.resume = True
            opts.output_path = str(existing_path)

    # Still not resuming: creating new output path
    if not opts.train.resume:
        opts.output_path = str(get_increased_path(opts.output_path))
        Path(opts.output_path).mkdir(parents=True, exist_ok=True)


def _setup_comet(opts, args, auto_resumed):
    """Create or resume the Comet experiment and log run metadata to it.

    When resuming, the Comet id found in ``opts.output_path`` (via
    ``get_existing_comet_id``) is used to continue that experiment. If no id
    is found, a new experiment is started, and it uploads the ``climategan``
    package folder and this script as assets.

    The function then logs, in order:
    - the optional note;
    - the merged tags (``args.comet_tags`` + ``opts.comet.tags`` plus a
      ``branch:`` tag, written back to ``opts.comet.tags``);
    - the flattened opts;
    - any auto-resume details.

    Finally it writes the experiment URL to a ``comet_url.txt``-based path
    from ``get_increased_path``.

    Args:
        opts: the loaded options; ``opts.comet.tags`` may be overwritten.
        args: the ``args`` group from the Hydra config.
        auto_resumed: resume details; a "continuing exp id" entry is added
            when an existing experiment is resumed.

    Returns:
        The ``Experiment`` or ``ExistingExperiment`` instance.

    Raises:
        FileNotFoundError: if resuming and ``opts.output_path`` does not exist.
    """
    exp = None
    if opts.train.resume:
        # Is resuming: get existing comet exp id. Explicit check instead of
        # `assert`, which is stripped when Python runs with -O.
        if not Path(opts.output_path).exists():
            raise FileNotFoundError(
                f"Output_path does not exist: {opts.output_path}"
            )
        comet_previous_id = get_existing_comet_id(opts.output_path)
        if comet_previous_id is None:
            print("WARNING could not retrieve previous comet id")
            print(f"from {opts.output_path}")
        else:
            # Continue existing experiment
            print("Continuing previous experiment", comet_previous_id)
            auto_resumed["continuing exp id"] = comet_previous_id
            exp = ExistingExperiment(
                previous_experiment=comet_previous_id, **comet_kwargs
            )
            print("Comet Experiment resumed")

    if exp is None:
        # Create new experiment and snapshot the code that produced it
        print("Starting new experiment")
        exp = Experiment(project_name="climategan", **comet_kwargs)
        exp.log_asset_folder(
            str(Path(__file__).parent / "climategan"),
            recursive=True,
            log_file_name=True,
        )
        exp.log_asset(str(Path(__file__)))

    # Log note
    if args.note:
        exp.log_parameter("note", args.note)

    # Merge and log tags; the branch tag is only added alongside other tags
    if args.comet_tags or opts.comet.tags:
        tags = {f"branch:{opts.git_branch}"}
        if args.comet_tags:
            tags.update(args.comet_tags)
        if opts.comet.tags:
            tags.update(opts.comet.tags)
        opts.comet.tags = list(tags)
        print("Logging to comet.ml with tags", opts.comet.tags)
        exp.add_tags(opts.comet.tags)

    # Log all opts
    exp.log_parameters(flatten_opts(opts))
    if auto_resumed:
        exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))

    # allow some time for comet to get its url
    sleep(1)

    # Save comet exp url
    url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
    with open(url_path, "w", encoding="utf-8") as f:
        f.write(exp.url)

    return exp


def _save_opts(opts):
    """Dump the final ``opts`` as YAML into the output directory.

    The file name comes from ``get_increased_path(output_path / "opts.yaml")``.
    """
    opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
    with opts_path.open("w", encoding="utf-8") as f:
        yaml.safe_dump(opts.to_dict(), f)


# requires hydra-core==0.11.3 and omegaconf==1.4.1
@hydra.main(config_path=hydra_config_path, strict=False)
def main(opts):
    """Load options, prepare the output directory and Comet, then train.

    Opts prevalence:
        1. Load the file given in ``args.default`` (or
           ``shared/trainer/defaults.yaml`` if none is provided).
        2. Update it with the file given in ``args.config`` (no update if none
           is provided). With ``train.resume`` set, the ``opts.yaml`` saved in
           the output path is used instead, when it exists.
        3. Update it with the parsed command-line arguments.

    e.g.
        `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
        loads the defaults, overrides them with the values in large-lr.yaml,
        and sets batch_size to 10.

    Other ``args`` fields read here:
        resume, no_comet, note, comet_tags.

    Args:
        opts: the Hydra/OmegaConf config injected by ``hydra.main``.

    Raises:
        ValueError: if the Hydra config has no ``args`` group.
    """
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    hydra_opts = Dict(OmegaConf.to_container(opts))
    args = hydra_opts.pop("args", None)
    if args is None:
        # Without this, the first `args.<field>` access fails with an opaque
        # AttributeError on None.
        raise ValueError(
            f"Hydra config {hydra_config_path} does not define an `args` group"
        )
    auto_resumed = {}

    config_path = _resolve_config_path(args, hydra_opts)
    default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"

    # -----------------------
    # -----  Load opts  -----
    # -----------------------
    opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
    if args.resume:
        opts.train.resume = True

    opts.jobID = os.environ.get("SLURM_JOBID")
    opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
    opts.output_path = str(env_to_path(opts.output_path))
    print("Config output_path:", opts.output_path)

    # -------------------------------
    # -----  Check output_path  -----
    # -------------------------------
    _resolve_output_path(opts, auto_resumed)

    # Copy the opts's sbatch_file to output_path
    copy_run_files(opts)
    # store git hash
    opts.git_hash = get_git_revision_hash()
    opts.git_branch = get_git_branch()

    # ----------------------------------
    # -----  Set Comet Experiment  -----
    # ----------------------------------
    exp = None
    if not args.no_comet:
        exp = _setup_comet(opts, args, auto_resumed)

    # Save config file (also when Comet is disabled)
    _save_opts(opts)

    pprint("Running model in", opts.output_path)

    # -------------------
    # -----  Train  -----
    # -------------------
    trainer = Trainer(opts, comet_exp=exp, verbose=1)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    # -----------------------------
    # -----  End of training  -----
    # -----------------------------
    pprint("Done training")
    # NOTE(review): opts.jobID is None outside SLURM; confirm kill_job
    # handles that case.
    kill_job(opts.jobID)
if __name__ == "__main__":
    # No arguments: hydra.main parses the command line and injects the config.
    main()