| | from pytorch_lightning.callbacks import Callback |
| | import os |
| | import shutil |
| | from omegaconf import OmegaConf |
| |
|
class SetupCallback(Callback):
    """Create run directories and persist the run configuration at fit start.

    Args:
        now: Run identifier; used as the prefix of the saved config file
            name (``<cfgdir>/<now>-project.yaml``).
        logdir: Directory that receives ``argv_content.txt``.
        ckptdir: Directory created for checkpoints (only created here).
        cfgdir: Directory that receives the saved project config.
        config: Anything ``OmegaConf.to_yaml`` / ``OmegaConf.save`` accept
            (presumably a DictConfig — confirm against callers).
        argv_content: Value written verbatim via ``str()``; with the default
            ``None`` the file contains the text ``None``.
    """

    def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
        super().__init__()
        self.now = now
        self.logdir, self.ckptdir, self.cfgdir = logdir, ckptdir, cfgdir
        self.config = config
        self.argv_content = argv_content

    def on_fit_start(self, trainer, pl_module):
        """Make the output directories, echo the config, and save it to disk.

        NOTE(review): there is no rank check here, so under multi-process
        training every process prints and writes these files — confirm this
        is intended.
        """
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        config_path = os.path.join(self.cfgdir, f"{self.now}-project.yaml")
        OmegaConf.save(self.config, config_path)

        argv_path = os.path.join(self.logdir, "argv_content.txt")
        with open(argv_path, "w") as handle:
            handle.write(str(self.argv_content))
| |
|
class BackupCodeCallback(Callback):
    """Copy a source tree into ``<backup_dir>/code`` when training starts.

    Any previous ``<backup_dir>/code`` is removed first, so the copy always
    reflects the current source. The backup is best-effort: a failure is
    printed (with its cause) and training continues.

    Args:
        source_dir: Directory to copy.
        backup_dir: Destination root; the copy goes into its ``code``
            subdirectory.
        ignore_patterns: Optional filter for ``shutil.copytree``. May be
            ``None``, a callable accepted by copytree's ``ignore`` argument
            (e.g. the result of ``shutil.ignore_patterns(...)``), a single
            glob string, or an iterable of glob strings.
    """

    def __init__(self, source_dir, backup_dir, ignore_patterns=None):
        super().__init__()
        self.source_dir = source_dir
        self.backup_dir = backup_dir
        self.ignore_patterns = ignore_patterns

    def _ignore_callable(self):
        """Return ``ignore_patterns`` in the callable form copytree expects."""
        patterns = self.ignore_patterns
        if patterns is None or callable(patterns):
            return patterns
        if isinstance(patterns, str):
            return shutil.ignore_patterns(patterns)
        return shutil.ignore_patterns(*patterns)

    def on_train_start(self, trainer, pl_module):
        # Every process runs this hook under multi-process training; several
        # processes doing rmtree/copytree on the same destination race and can
        # leave a partial copy, so only the global-zero process backs up.
        # getattr default keeps single-process / stub trainers working.
        if not getattr(trainer, "is_global_zero", True):
            return

        code_dir = os.path.join(self.backup_dir, "code")
        try:
            os.makedirs(self.backup_dir, exist_ok=True)
            # copytree refuses an existing destination by default, so drop the
            # stale copy first.
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree(self.source_dir, code_dir, ignore=self._ignore_callable())
        except Exception as err:
            # Best-effort: a failed backup must not abort training, but report
            # the cause. (The former bare ``except:`` also swallowed
            # KeyboardInterrupt/SystemExit and hid the error.)
            print(f"Fail in copying file backed up to {self.backup_dir}: {err}")
        else:
            print(f"Code file backed up to {self.backup_dir}")