import torch import torch.nn as nn from utils.misc import get_rank class BaseModel(nn.Module): def __init__(self, config): super().__init__() self.config = config self.rank = get_rank() self.setup() if self.config.get('weights', None): self.load_state_dict(torch.load(self.config.weights)) def setup(self): raise NotImplementedError def update_step(self, epoch, global_step): pass def train(self, mode=True): return super().train(mode=mode) def eval(self): return super().eval() def regularizations(self, out): return {} @torch.no_grad() def export(self, export_config): return {}