Spaces:
Running
on
Zero
Running
on
Zero
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| from .config import parse_structured | |
| from .misc import get_device, load_module_weights | |
| from .typing import * | |
class Configurable:
    """Base class for objects configured through a nested ``Config`` schema.

    Subclasses override ``Config`` to declare their own fields. The
    constructor hands ``self.Config`` and the caller-supplied ``cfg`` to
    ``parse_structured`` and stores whatever it returns on ``self.cfg``.
    """

    # NOTE(review): ``@dataclass`` restored. The module imports ``dataclass``
    # but nothing used it, and ``parse_structured`` presumably needs a
    # dataclass schema to build the structured config -- confirm against
    # ``.config``.
    @dataclass
    class Config:
        """Empty default schema; subclasses declare their own fields."""

    def __init__(self, cfg: Optional[dict] = None) -> None:
        """Parse ``cfg`` against ``self.Config`` and keep the result.

        Args:
            cfg: Raw configuration values, or ``None``; passed through to
                ``parse_structured`` unchanged.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
class Updateable:
    """Mixin that propagates per-step update hooks through an object tree.

    ``do_update_step`` / ``do_update_step_end`` first recurse into every
    public attribute that is itself an ``Updateable`` and then invoke this
    object's own ``update_step`` / ``update_step_end`` hook, so children are
    always updated before their owner.
    """

    def _updateable_children(self):
        """Yield each public attribute of ``self`` that is an ``Updateable``.

        Attributes whose name starts with ``_`` are skipped. Attributes are
        read lazily, in ``self.__dir__()`` order, so each child is fetched
        just before the caller processes it.
        """
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                child = getattr(self, attr)
            except Exception:
                # Best-effort: some attributes (e.g. properties) may raise when
                # read and are simply skipped. ``Exception`` rather than a bare
                # ``except`` so KeyboardInterrupt / SystemExit still propagate.
                continue
            if isinstance(child, Updateable):
                yield child

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        """Run ``update_step`` on all updateable children, then on ``self``.

        Args:
            epoch: Current epoch, forwarded to every hook.
            global_step: Current global step, forwarded to every hook.
            on_load_weights: Forwarded to every hook unchanged; see
                ``update_step`` for what it signals.
        """
        for child in self._updateable_children():
            child.do_update_step(
                epoch, global_step, on_load_weights=on_load_weights
            )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        """Run ``update_step_end`` on all updateable children, then on ``self``.

        Args:
            epoch: Current epoch, forwarded to every hook.
            global_step: Current global step, forwarded to every hook.
        """
        for child in self._updateable_children():
            child.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Hook for custom per-step update logic; does nothing by default."""
        # override this method to implement custom update logic
        # if on_load_weights is True, you should be careful doing things related to model evaluations,
        # as the models and tensors are not guaranteed to be on the same device
        pass

    def update_step_end(self, epoch: int, global_step: int):
        """Hook for custom end-of-step logic; does nothing by default."""
        pass
def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step`` on ``module`` when it is an ``Updateable``.

    Objects of any other type are ignored, so callers can pass arbitrary
    components without checking their type first.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step(epoch, global_step)
def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step_end`` on ``module`` when it is an ``Updateable``.

    Objects of any other type are ignored, so callers can pass arbitrary
    components without checking their type first.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step_end(epoch, global_step)
class BaseObject(Updateable):
    """Non-``nn.Module`` base for configurable, updateable components.

    Construction parses the config, records the value returned by
    ``get_device()`` on ``self.device`` and then calls ``configure`` with any
    extra constructor arguments. Subclasses normally override ``Config`` and
    ``configure`` rather than ``__init__``.
    """

    # NOTE(review): ``@dataclass`` restored. The module imports ``dataclass``
    # but nothing used it, and ``parse_structured`` presumably needs a
    # dataclass schema to build the structured config -- confirm against
    # ``.config``.
    @dataclass
    class Config:
        """Empty default schema; subclasses declare their own fields."""

    cfg: Config  # add this to every subclass of BaseObject to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, record the device and run ``configure``.

        Args:
            cfg: Raw configuration (plain dict, ``DictConfig`` or ``None``),
                passed to ``parse_structured`` together with ``self.Config``.
            *args: Forwarded to ``configure``.
            **kwargs: Forwarded to ``configure``.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        # Runs last so subclasses can rely on self.cfg and self.device.
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Subclass setup hook called at the end of ``__init__``; no-op here."""
        pass
class BaseModule(nn.Module, Updateable):
    """``nn.Module`` base with structured config, update hooks and weight loading.

    Construction parses the config, records the value returned by
    ``get_device()`` on ``self.device``, calls ``configure`` with any extra
    constructor arguments and, when ``cfg.weights`` is set, loads that
    checkpoint into the module and replays the update hooks for the
    checkpoint's epoch / global step.
    """

    # NOTE(review): ``@dataclass`` restored. The module imports ``dataclass``
    # but nothing used it, and ``parse_structured`` presumably needs a
    # dataclass schema to build the structured config -- confirm against
    # ``.config``.
    @dataclass
    class Config:
        # Optional checkpoint spec in the form ``path/to/weights:module_name``.
        weights: Optional[str] = None

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, run ``configure`` and optionally load weights.

        Args:
            cfg: Raw configuration (plain dict, ``DictConfig`` or ``None``),
                passed to ``parse_structured`` together with ``self.Config``.
            *args: Forwarded to ``configure``.
            **kwargs: Forwarded to ``configure``.

        Raises:
            ValueError: If ``cfg.weights`` is set but contains no ``:``
                separating the path from the module name.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        # Plain dict, so objects stored here are not registered as submodules.
        self._non_modules = {}
        self.configure(*args, **kwargs)

        if self.cfg.weights is None:
            return

        # format: path/to/weights:module_name
        # Split on the LAST colon so paths that themselves contain colons
        # (Windows drive letters, URLs) are still accepted.
        weights_path, separator, module_name = self.cfg.weights.rpartition(":")
        if not separator:
            raise ValueError(
                "cfg.weights must have the form 'path/to/weights:module_name', "
                f"got {self.cfg.weights!r}"
            )
        state_dict, epoch, global_step = load_module_weights(
            weights_path, module_name=module_name, map_location="cpu"
        )
        self.load_state_dict(state_dict)
        # restore states that depend on the training progress of the checkpoint
        self.do_update_step(epoch, global_step, on_load_weights=True)

    def configure(self, *args, **kwargs) -> None:
        """Subclass setup hook called from ``__init__``; no-op here."""
        pass

    def register_non_module(self, name: str, module: nn.Module) -> None:
        """Store ``module`` under ``name`` without registering it as a submodule."""
        # non-modules won't be treated as model parameters
        self._non_modules[name] = module

    def non_module(self, name: str):
        """Return the object registered under ``name``, or ``None`` if absent."""
        return self._non_modules.get(name)