Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| from pathlib import Path | |
| from time import sleep, time | |
| import hydra | |
| import yaml | |
| from addict import Dict | |
| from comet_ml import ExistingExperiment, Experiment | |
| from omegaconf import OmegaConf | |
| from climategan.trainer import Trainer | |
| from climategan.utils import ( | |
| comet_kwargs, | |
| copy_run_files, | |
| env_to_path, | |
| find_existing_training, | |
| flatten_opts, | |
| get_existing_comet_id, | |
| get_git_branch, | |
| get_git_revision_hash, | |
| get_increased_path, | |
| kill_job, | |
| load_opts, | |
| pprint, | |
| ) | |
# Configure the root logger and raise its threshold to ERROR:
# INFO/WARNING messages from imported libraries are suppressed.
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

# Hydra config describing the trainer's command-line schema;
# consumed by the @hydra.main decorator on main() below.
hydra_config_path = Path(__file__).resolve().parent / "shared/trainer/config.yaml"
# requires hydra-core==0.11.3 and omegaconf==1.4.1
# NOTE(review): restored decorator — `hydra_config_path` was otherwise unused
# and `main()` is called with no arguments at the bottom of the file, which
# only works if hydra injects `opts`. hydra-core 0.11.x (pinned in the comment
# above) accepts a file config_path and a `strict` flag; confirm `strict=False`
# matches the upstream repository.
@hydra.main(config_path=str(hydra_config_path), strict=False)
def main(opts):
    """
    Train a ClimateGAN model. `opts` is injected by @hydra.main.

    Opts prevalence:
    1. Load file specified in args.default (or shared/trainer/defaults.yaml
       if none is provided)
    2. Update with file specified in args.config (or no update if none is provided)
    3. Update with parsed command-line arguments

    e.g.
    `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
    loads defaults, overrides with values in large-lr.yaml and sets batch_size to 10
    """
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    hydra_opts = Dict(OmegaConf.to_container(opts))
    args = hydra_opts.pop("args", None)
    auto_resumed = {}  # bookkeeping about any auto-resume decision, logged to comet

    config_path = args.config
    if hydra_opts.train.resume:
        # When resuming, prefer the opts.yaml previously saved in output_path
        out_ = str(env_to_path(hydra_opts.output_path))
        config_path = Path(out_) / "opts.yaml"
        if not config_path.exists():
            config_path = None
            print("WARNING: could not reuse the opts in {}".format(out_))

    default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"

    # -----------------------
    # -----  Load opts  -----
    # -----------------------
    opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
    if args.resume:
        opts.train.resume = True

    # Record SLURM context (None when not running under SLURM)
    opts.jobID = os.environ.get("SLURM_JOBID")
    opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
    opts.output_path = str(env_to_path(opts.output_path))
    print("Config output_path:", opts.output_path)

    exp = comet_previous_id = None

    # -------------------------------
    # -----  Check output_path  -----
    # -------------------------------

    # Auto-continue if same slurm job ID (=job was requeued)
    if not opts.train.resume and opts.train.auto_resume:
        print("\n\nTrying to auto-resume...")
        existing_path = find_existing_training(opts)
        if existing_path is not None and existing_path.exists():
            auto_resumed["original output_path"] = str(opts.output_path)
            auto_resumed["existing_path"] = str(existing_path)
            opts.train.resume = True
            opts.output_path = str(existing_path)

    # Still not resuming: creating new output path
    if not opts.train.resume:
        opts.output_path = str(get_increased_path(opts.output_path))
        Path(opts.output_path).mkdir(parents=True, exist_ok=True)

    # Copy the opts's sbatch_file to output_path
    copy_run_files(opts)
    # store git hash
    opts.git_hash = get_git_revision_hash()
    opts.git_branch = get_git_branch()

    if not args.no_comet:
        # ----------------------------------
        # -----  Set Comet Experiment  -----
        # ----------------------------------

        if opts.train.resume:
            # Is resuming: get existing comet exp id
            assert Path(opts.output_path).exists(), "Output_path does not exist"

            comet_previous_id = get_existing_comet_id(opts.output_path)
            # Continue existing experiment
            if comet_previous_id is None:
                print("WARNING could not retrieve previous comet id")
                print(f"from {opts.output_path}")
            else:
                print("Continuing previous experiment", comet_previous_id)
                auto_resumed["continuing exp id"] = comet_previous_id
                exp = ExistingExperiment(
                    previous_experiment=comet_previous_id, **comet_kwargs
                )
                print("Comet Experiment resumed")

        if exp is None:
            # Create new experiment (also the fallback when resuming failed
            # to retrieve a previous id)
            print("Starting new experiment")
            exp = Experiment(project_name="climategan", **comet_kwargs)
            # Snapshot the source code alongside the experiment
            exp.log_asset_folder(
                str(Path(__file__).parent / "climategan"),
                recursive=True,
                log_file_name=True,
            )
            exp.log_asset(str(Path(__file__)))

        # Log note
        if args.note:
            exp.log_parameter("note", args.note)

        # Merge and log tags (command-line tags + config tags, deduplicated)
        if args.comet_tags or opts.comet.tags:
            tags = set([f"branch:{opts.git_branch}"])
            if args.comet_tags:
                tags.update(args.comet_tags)
            if opts.comet.tags:
                tags.update(opts.comet.tags)
            opts.comet.tags = list(tags)
            print("Logging to comet.ml with tags", opts.comet.tags)
            exp.add_tags(opts.comet.tags)

        # Log all opts
        exp.log_parameters(flatten_opts(opts))
        if auto_resumed:
            exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))

        # allow some time for comet to get its url
        sleep(1)

        # Save comet exp url
        url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
        with open(url_path, "w") as f:
            f.write(exp.url)

    # Save config file so a later run can resume from output_path
    opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
    with (opts_path).open("w") as f:
        yaml.safe_dump(opts.to_dict(), f)

    pprint("Running model in", opts.output_path)

    # -------------------
    # -----  Train  -----
    # -------------------
    trainer = Trainer(opts, comet_exp=exp, verbose=1)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    # -----------------------------
    # -----  End of training  -----
    # -----------------------------
    pprint("Done training")
    kill_job(opts.jobID)
if __name__ == "__main__":
    # Script entry point. NOTE(review): `main` is called with no arguments,
    # which only works if `main` is a hydra entry point (hydra is imported
    # above and `hydra_config_path` is built for it) so the framework injects
    # `opts` — confirm the @hydra.main decorator is present on `main`.
    main()