import os
import shutil
import subprocess
import pytorch_lightning
from .config import dump_config
from .misc import parse_version
if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
    from pytorch_lightning.callbacks import Callback
else:
    from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
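
# Base callback that resolves a versioned output directory under `save_root`
# (save_root/version_0, version_1, ...). Subclasses write their artifacts into
# the `savedir` property; if `use_version` is False, `save_root` is used directly.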
class VersionedCallback(Callback):
    def __init__(self, save_root, version=None, use_version=True):
        self.save_root = save_root
        self._version = version
        self.use_version = use_version
    @property
    def version(self) -> int:
        """Get the experiment version.
        Returns:
            The experiment version if specified else the next version.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version
    def _get_next_version(self):
        existing_versions = []
        if os.path.isdir(self.save_root):
            for f in os.listdir(self.save_root):
                bn = os.path.basename(f)
                if bn.startswith("version_"):
                    dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
                    existing_versions.append(int(dir_ver))
        if len(existing_versions) == 0:
            return 0
        return max(existing_versions) + 1
    @property
    def savedir(self):
        if not self.use_version:
            return self.save_root
        return os.path.join(
            self.save_root,
            (
                self.version
                if isinstance(self.version, str)
                else f"version_{self.version}"
            ),
        )
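
# Snapshots the current source tree into the versioned save directory at the
# start of fit(): all git-tracked files (excluding "load/*") plus untracked
# files that are not git-ignored.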
class CodeSnapshotCallback(VersionedCallback):
    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
    def get_file_list(self):
        return [
            b.decode()
            for b in set(
                subprocess.check_output(
                    'git ls-files -- ":!:load/*"', shell=True
                ).splitlines()
            )
            | set(  # hard code, TODO: use config to exclude folders or files
                subprocess.check_output(
                    "git ls-files --others --exclude-standard", shell=True
                ).splitlines()
            )
        ]
    @rank_zero_only
    def save_code_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
            shutil.copyfile(f, os.path.join(self.savedir, f))
    def on_fit_start(self, trainer, pl_module):
        try:
            self.save_code_snapshot()
        except Exception:
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
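
# Saves both the parsed config ("parsed.yaml") and a verbatim copy of the raw
# config file ("raw.yaml") into the versioned save directory at the start of fit().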
class ConfigSnapshotCallback(VersionedCallback):
    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config
    @rank_zero_only
    def save_config_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
        shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))
    def on_fit_start(self, trainer, pl_module):
        self.save_config_snapshot()
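
# TQDM progress bar that hides the logger version number ("v_num") from the
# metrics shown next to the bar.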
class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, *args, **kwargs):
        # don't show the version number
        items = super().get_metrics(*args, **kwargs)
        items.pop("v_num", None)
        return items
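
# Writes a single-line, human-readable progress message to `save_path`,
# overwriting the file on every update (rank-zero only), so that external
# processes can poll the current training stage.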
class ProgressCallback(Callback):
    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None
    @property
    def file_handle(self):
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w")
        return self._file_handle
    @rank_zero_only
    def write(self, msg: str) -> None:
        self.file_handle.seek(0)
        self.file_handle.truncate()
        self.file_handle.write(msg)
        self.file_handle.flush()
    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.write(
            f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
        )
    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        self.write(f"Rendering validation image ...")
    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        self.write(f"Rendering video ...")
    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        self.write(f"Exporting mesh assets ...")