|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from typing import Optional |
|
|
|
import torch.optim |
|
|
|
from accelerate import Accelerator |
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase |
|
from pytorch3d.implicitron.tools import model_io |
|
from pytorch3d.implicitron.tools.config import ( |
|
registry, |
|
ReplaceableBase, |
|
run_auto_creation, |
|
) |
|
from pytorch3d.implicitron.tools.stats import Stats |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ModelFactoryBase(ReplaceableBase):
    """
    Replaceable base for factories that construct an Implicitron model and,
    optionally, restore it (together with its training statistics) from a
    previously saved state.
    """

    # Whether to attempt restoring a previously saved state.
    resume: bool = True

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Build the model, possibly loading weights from an earlier run.

        Returns:
            An instance of ImplicitronModelBase.
        """
        raise NotImplementedError

    def load_stats(self, **kwargs) -> Stats:
        """
        Create a fresh Stats object or load a previously saved one.
        """
        raise NotImplementedError
|
|
|
|
|
@registry.register
class ImplicitronModelFactory(ModelFactoryBase):
    """
    Factory producing an implicit-rendering model, optionally restored from a
    checkpoint found in an experiment directory.

    Members:
        model: The ImplicitronModelBase object this factory manages.
        model_class_type: Registry name of the concrete model class to build.
        resume: If True, attempt to load the last checkpoint from the
            `exp_dir` passed to __call__. If loading fails, a model with
            initial weights is returned unless `force_resume` is True.
        resume_epoch: If `resume` is True, resume from this epoch's
            checkpoint; when `resume_epoch` <= 0, resume from the latest one.
        force_resume: If True, raise FileNotFoundError when no checkpoint
            can be found to resume from.
    """

    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = True
    resume_epoch: int = -1
    force_resume: bool = False

    def __post_init__(self):
        # Instantiate the configured `model` member (Implicitron config system).
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Return the factory's model, with weights loaded from a checkpoint
        when `resume` / `resume_epoch` request it.

        Args:
            exp_dir: Root experiment directory searched for checkpoints.
            accelerator: Optional Accelerator; used to remap loaded weights
                onto the local device of non-main processes.

        Returns:
            The model, with weights optionally restored from a checkpoint.

        Raises:
            ValueError: If `resume_epoch` > 0 but no checkpoint exists for
                that epoch.
            FileNotFoundError: If `force_resume` is True and no checkpoint
                is found in `exp_dir`.
        """
        # Remember the model's log variables (falling back to the objective
        # only) so they can be re-applied after a non-strict weight load.
        log_vars = list(getattr(self.model, "log_vars", ["objective"]))

        if self.resume_epoch > 0:
            # An explicit epoch was requested; its checkpoint must exist.
            checkpoint_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
            if not os.path.isfile(checkpoint_path):
                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        else:
            checkpoint_path = model_io.find_last_checkpoint(exp_dir)

        if checkpoint_path is None:
            if self.force_resume:
                raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
            return self.model

        logger.info(f"Found previous model {checkpoint_path}")
        if not (self.force_resume or self.resume):
            logger.info("Not resuming -> starting from scratch.")
            return self.model

        logger.info("Resuming.")
        # On non-main processes, remap weights saved from cuda:0 onto the
        # process-local GPU.
        map_location = None
        if accelerator is not None and not accelerator.is_local_main_process:
            map_location = {"cuda:0": f"cuda:{accelerator.local_process_index}"}
        # NOTE(review): torch.load without weights_only — assumes the
        # checkpoint is trusted local output of this experiment.
        state_dict = torch.load(
            model_io.get_model_path(checkpoint_path), map_location=map_location
        )

        try:
            self.model.load_state_dict(state_dict, strict=True)
        except RuntimeError as err:
            # Fall back to a non-strict load (e.g. architecture drift between
            # the saved and the current model) and restore the log variables.
            logger.error(err)
            logger.info("Cannot load state dict in strict mode! -> trying non-strict")
            self.model.load_state_dict(state_dict, strict=False)
            self.model.log_vars = log_vars

        return self.model
|
|