| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import logging |
| | import os |
| | import psutil |
| | import time |
| | import subprocess |
| | from torch.utils.tensorboard import SummaryWriter |
| | |
# Port the TensorBoard server binds to when get_writer() receives tensorboard_port=True.
DEFAULT_TENSORBOARD_PORT: int = 6006
| | |
| |
|
| |
|
def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger | None = None):
    """
    Set up TensorBoard logging and a checkpoint directory for a PyTorch run.

    Creates ``<path>/checkpoints`` and, when TensorBoard is enabled, ``<path>/logs``
    (existing contents are left untouched). It then opens a ``SummaryWriter`` on the
    logs directory and starts a TensorBoard server on the requested port. Any process
    already listening on that port is terminated first so the new server can bind it.

    Args:
        path (str): Root directory under which ``logs`` and ``checkpoints`` are created.
        tensorboard_port (int | bool): Port for the TensorBoard server. ``True`` uses
            ``DEFAULT_TENSORBOARD_PORT``; ``False`` disables TensorBoard (no writer is
            created and no server is started).
        logger (logging.Logger | None): Optional logger for progress messages.

    Returns:
        tuple: ``(writer, checkpoints_path)``; ``writer`` is a ``SummaryWriter``, or
        ``None`` when TensorBoard is disabled.

    Example:
        >>> tensor_writer, checkpoint_dir = get_writer('/path/to/run/', True)
    """
    checkpoints_path = os.path.join(path, 'checkpoints')
    # Create the checkpoint directory in every mode so the returned path is always usable.
    os.makedirs(checkpoints_path, exist_ok=True)

    # Compare with `is`: bool is a subclass of int, so True must not be treated as port 1.
    if tensorboard_port is False:
        return None, checkpoints_path
    if tensorboard_port is True:
        tensorboard_port = DEFAULT_TENSORBOARD_PORT

    logs_path = os.path.join(path, 'logs')
    os.makedirs(logs_path, exist_ok=True)
    writer = SummaryWriter(log_dir=logs_path)

    if logger is not None:
        logger.info(f"TensorBoard logs will be stored in: {logs_path}")
        logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")

    if _stop_port_listeners(tensorboard_port, logger):
        # Give the OS a moment to release the port before the new server binds it.
        time.sleep(5)
    _launch_tensorboard(logs_path, tensorboard_port, logger)

    return writer, checkpoints_path


def _stop_port_listeners(port: int, logger: logging.Logger | None) -> bool:
    """
    Terminate every process listening on ``port`` so a new server can bind it.

    Args:
        port (int): TCP port to free.
        logger (logging.Logger | None): Optional logger for warnings.

    Returns:
        bool: True if at least one listening process was found and a stop was attempted.
    """
    try:
        connections = psutil.net_connections(kind='inet')
    except psutil.AccessDenied:
        # Some platforms require elevated privileges to enumerate sockets.
        if logger is not None:
            logger.warning(f"Cannot inspect open ports; not checking whether port {port} is in use")
        return False

    own_pid = os.getpid()
    handled_pids = set()
    attempted = False
    for conn in connections:
        # laddr may be an empty tuple for some sockets; only listening sockets on `port` matter.
        if not conn.laddr or conn.laddr.port != port or conn.status != psutil.CONN_LISTEN:
            continue
        pid = conn.pid
        # pid is None when the owning process is not visible to us. psutil.Process(None)
        # would target *this* process, so never pass it through (nor our own pid).
        if pid is None or pid == own_pid:
            if logger is not None:
                logger.warning(f"Port {port} is in use by a process that cannot be stopped (pid={pid})")
            continue
        # A process listening on several addresses (e.g. IPv4 and IPv6) appears once per socket.
        if pid in handled_pids:
            continue
        handled_pids.add(pid)
        attempted = True
        if logger is not None:
            logger.warning(f"Terminating process with PID {pid} already listening on port {port}")
        try:
            proc = psutil.Process(pid)
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except psutil.TimeoutExpired:
                # The process ignored the terminate request; escalate so the port is released.
                proc.kill()
                proc.wait(timeout=3)
        except psutil.NoSuchProcess:
            pass  # It exited on its own between listing and signalling.
        except (psutil.AccessDenied, psutil.TimeoutExpired) as err:
            if logger is not None:
                logger.warning(f"Could not stop process with PID {pid}: {err}")
    return attempted


def _launch_tensorboard(logs_path: str, port: int, logger: logging.Logger | None):
    """
    Start a TensorBoard server serving ``logs_path`` on ``port``.

    Returns:
        subprocess.Popen | None: The server process, or None if the ``tensorboard``
        executable could not be found.
    """
    # Pass an argument list rather than a shell string. Paths containing spaces or
    # shell metacharacters then reach TensorBoard verbatim and are never shell-interpreted.
    command = ['tensorboard', f'--logdir={logs_path}', '--host=0.0.0.0', f'--port={port}']
    try:
        process = subprocess.Popen(command)
    except FileNotFoundError:
        if logger is not None:
            logger.error("Could not start TensorBoard: 'tensorboard' executable not found on PATH")
        return None
    if logger is not None:
        logger.info(f'TensorBoard running at http://0.0.0.0:{port}/ (pid={process.pid})')
    return process
| | |
| | |
| | |
| |
|