|
|
|
|
|
|
|
|
|
|
|
|
|
"""Helpers for managing the run/training loop.""" |
|
|
|
import datetime |
|
import json |
|
import os |
|
import pprint |
|
import time |
|
import types |
|
|
|
from typing import Any |
|
|
|
from . import submit |
|
|
|
|
|
# The currently active RunContext, or None when no context is open.
# Assigned at the end of a successful RunContext.__init__ and cleared by RunContext.close().
_run_context = None


class RunContext(object):
    """Helper class for managing the run/training loop.

    The context hides the implementation details of a basic run/training loop.
    It sets things up properly, tells whether the run should be stopped, and cleans up afterwards.
    The user should call update() periodically and use should_stop() to determine whether the run should be stopped.

    Only one RunContext may be active at a time. The active instance is stored in the
    module-level ``_run_context`` and is returned by ``RunContext.get()``.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: (deprecated) The whole config module that is used for the current run.
    """

    def __init__(self, submit_config: "submit.SubmitConfig", config_module: types.ModuleType = None):
        global _run_context

        # Refuse to create a second context while one is active. Checked before any side
        # effect so that an active run's run.txt is never overwritten.
        assert _run_context is None, "a RunContext is already active; close it first"

        now = time.time()
        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False
        self.start_time = now
        self.last_update_time = now
        self.last_update_interval = 0.0
        self.progress_monitor_file_path = None

        # The deprecated config_module argument is accepted but ignored.
        if config_module is not None:
            print("RunContext.config_module parameter support has been removed.")

        # Record details about the run in <run_dir>/run.txt.
        self.run_txt_data = {
            "task_name": submit_config.task_name,
            "host_name": submit_config.host_name,
            "start_time": datetime.datetime.now().isoformat(sep=" "),
        }
        self._write_run_txt()

        # Register as the active context only once construction has fully succeeded.
        # Registering earlier would leave a half-built object installed if the write
        # above raised, making every later RunContext() fail the assert.
        _run_context = self

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def _write_run_txt(self) -> None:
        """(Re)write <run_dir>/run.txt with the current contents of run_txt_data."""
        with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.

        Should be called often enough, but not in a tight loop.

        Args:
            loss: Currently unused; accepted for API compatibility.
            cur_epoch: Currently unused; accepted for API compatibility.
            max_epoch: Currently unused; accepted for API compatibility.
        """
        assert not self.has_closed, "update() called on a closed RunContext"

        # Take a single timestamp so the recorded interval and last_update_time agree.
        now = time.time()
        self.last_update_interval = now - self.last_update_time
        self.last_update_time = now

        # The presence of abort.txt in the run directory sets the stop flag
        # (presumably created externally to request that the run stop).
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.

        Records the stop time in run.txt and detaches this instance from the
        module-level singleton. Calls after the first one do nothing.
        """
        global _run_context

        if self.has_closed:
            return

        try:
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            self._write_run_txt()
        finally:
            # Mark closed and detach the singleton even if writing run.txt failed,
            # so a failed close does not block creation of future contexts.
            self.has_closed = True
            if _run_context is self:
                _run_context = None

    @staticmethod
    def get() -> "RunContext":
        """Return the active RunContext, creating one from dnnlib.submit_config if none is active."""
        if _run_context is not None:
            return _run_context
        # Imported lazily, only when a new context has to be created.
        import dnnlib
        return RunContext(dnnlib.submit_config)
|
|