"""Training entry point: Hydra-configured ClimateGAN training with Comet.ml tracking."""
# Standard library
import logging
import os
from pathlib import Path
from time import sleep, time

# Third-party
import hydra
import yaml
from addict import Dict
from comet_ml import ExistingExperiment, Experiment
from omegaconf import OmegaConf

# Project-local
from climategan.trainer import Trainer
from climategan.utils import (
    comet_kwargs,
    copy_run_files,
    env_to_path,
    find_existing_training,
    flatten_opts,
    get_existing_comet_id,
    get_git_branch,
    get_git_revision_hash,
    get_increased_path,
    kill_job,
    load_opts,
    pprint,
)
# Silence noisy library loggers: only ERROR-level records reach the console.
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

# Hydra entry-point config file, resolved relative to this script.
# requires hydra-core==0.11.3 and omegaconf==1.4.1
hydra_config_path = Path(__file__).resolve().parent / "shared/trainer/config.yaml"
# NOTE(review): the decorator was missing from the mangled source; it is required
# because `main()` is invoked with no arguments under `__main__`, and `hydra` is
# imported but otherwise unused. hydra-core==0.11.3 accepts `strict=False`.
@hydra.main(config_path=hydra_config_path, strict=False)
def main(opts):
    """
    Run a ClimateGAN training, optionally resuming a previous run and logging
    to Comet.ml.

    Opts prevalence:
    1. Load file specified in args.default (or shared/trainer/defaults.yaml
       if none is provided)
    2. Update with file specified in args.config (or no update if none is provided)
    3. Update with parsed command-line arguments

    e.g.
    `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
    loads defaults, overrides with values in large-lr.yaml and sets batch_size to 10
    """
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    hydra_opts = Dict(OmegaConf.to_container(opts))
    args = hydra_opts.pop("args", None)
    auto_resumed = {}

    config_path = args.config
    if hydra_opts.train.resume:
        # Resuming: prefer the opts.yaml saved by the run being resumed.
        out_ = str(env_to_path(hydra_opts.output_path))
        config_path = Path(out_) / "opts.yaml"
        if not config_path.exists():
            config_path = None
            print("WARNING: could not reuse the opts in {}".format(out_))

    default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"

    # -----------------------
    # -----  Load opts  -----
    # -----------------------
    opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
    if args.resume:
        opts.train.resume = True

    # Record SLURM metadata (None outside a SLURM job).
    opts.jobID = os.environ.get("SLURM_JOBID")
    opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
    opts.output_path = str(env_to_path(opts.output_path))
    print("Config output_path:", opts.output_path)

    exp = comet_previous_id = None

    # -------------------------------
    # -----  Check output_path  -----
    # -------------------------------

    # Auto-continue if same slurm job ID (=job was requeued)
    if not opts.train.resume and opts.train.auto_resume:
        print("\n\nTrying to auto-resume...")
        existing_path = find_existing_training(opts)
        if existing_path is not None and existing_path.exists():
            auto_resumed["original output_path"] = str(opts.output_path)
            auto_resumed["existing_path"] = str(existing_path)
            opts.train.resume = True
            opts.output_path = str(existing_path)

    # Still not resuming: creating new output path
    if not opts.train.resume:
        opts.output_path = str(get_increased_path(opts.output_path))
        Path(opts.output_path).mkdir(parents=True, exist_ok=True)

    # Copy the opts's sbatch_file to output_path
    copy_run_files(opts)
    # Store git state so the run is reproducible.
    opts.git_hash = get_git_revision_hash()
    opts.git_branch = get_git_branch()

    if not args.no_comet:
        # ----------------------------------
        # -----  Set Comet Experiment  -----
        # ----------------------------------

        if opts.train.resume:
            # Is resuming: get existing comet exp id
            assert Path(opts.output_path).exists(), "Output_path does not exist"

            comet_previous_id = get_existing_comet_id(opts.output_path)
            # Continue existing experiment
            if comet_previous_id is None:
                # (typo fixed: "retreive" -> "retrieve")
                print("WARNING could not retrieve previous comet id")
                print(f"from {opts.output_path}")
            else:
                print("Continuing previous experiment", comet_previous_id)
                auto_resumed["continuing exp id"] = comet_previous_id
                exp = ExistingExperiment(
                    previous_experiment=comet_previous_id, **comet_kwargs
                )
                print("Comet Experiment resumed")

        if exp is None:
            # Create new experiment (also reached when the previous id lookup failed).
            print("Starting new experiment")
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(Path(__file__).parent / "climategan"),
                recursive=True,
                log_file_name=True,
            )
            exp.log_asset(str(Path(__file__)))

        # Log note
        if args.note:
            exp.log_parameter("note", args.note)

        # Merge and log tags (command-line tags + config tags, deduplicated).
        if args.comet_tags or opts.comet.tags:
            tags = set([f"branch:{opts.git_branch}"])
            if args.comet_tags:
                tags.update(args.comet_tags)
            if opts.comet.tags:
                tags.update(opts.comet.tags)
            opts.comet.tags = list(tags)
            print("Logging to comet.ml with tags", opts.comet.tags)
            exp.add_tags(opts.comet.tags)

        # Log all opts
        exp.log_parameters(flatten_opts(opts))
        if auto_resumed:
            exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))

        # allow some time for comet to get its url
        sleep(1)

        # Save comet exp url
        url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
        with open(url_path, "w") as f:
            f.write(exp.url)

    # Save config file
    opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
    with opts_path.open("w") as f:
        yaml.safe_dump(opts.to_dict(), f)

    pprint("Running model in", opts.output_path)

    # -------------------
    # -----  Train  -----
    # -------------------
    trainer = Trainer(opts, comet_exp=exp, verbose=1)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    # -----------------------------
    # -----  End of training  -----
    # -----------------------------
    pprint("Done training")
    kill_job(opts.jobID)
if __name__ == "__main__":
    # Hydra parses the command line and injects the resulting config into main().
    main()