Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from utils.misc import get_rank | |
class BaseModel(nn.Module):
    """Base class for config-driven models.

    Construction stores the config, records the value returned by
    ``get_rank()``, calls the subclass-provided :meth:`setup`, and then
    restores a checkpoint when the config has a truthy ``weights`` entry.
    The remaining methods are default hooks that subclasses may override.
    """

    def __init__(self, config):
        """Build the model and optionally restore weights.

        Args:
            config: Configuration object supporting ``.get(key, default)``.
                A truthy ``weights`` entry is passed to ``torch.load`` and
                the loaded object is handed to ``load_state_dict``.

        Raises:
            NotImplementedError: If the subclass does not override
                :meth:`setup`.
        """
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.setup()

        # Read the entry once so the check and the load use the same value
        # (and so mapping-style configs without attribute access also work).
        weights = self.config.get('weights', None)
        if weights:
            # Deserialize onto the CPU: a checkpoint saved from CUDA tensors
            # would otherwise fail on CPU-only hosts, or land on the saving
            # device in every process. load_state_dict copies the values into
            # the existing parameters, whatever device those live on.
            # NOTE(review): torch.load unpickles the file; only point
            # `weights` at checkpoints from a trusted source.
            state_dict = torch.load(weights, map_location='cpu')
            self.load_state_dict(state_dict)

    def setup(self):
        """Create the model's components; must be overridden by subclasses."""
        raise NotImplementedError

    def update_step(self, epoch, global_step):
        """Per-step hook; the base implementation does nothing."""
        pass

    def train(self, mode=True):
        """Forward to ``nn.Module.train`` and return its result."""
        return super().train(mode=mode)

    def eval(self):
        """Forward to ``nn.Module.eval`` and return its result."""
        return super().eval()

    def regularizations(self, out):
        """Return regularization terms for ``out``; empty dict by default."""
        return {}

    def export(self, export_config):
        """Return exported artifacts for ``export_config``; empty dict by default."""
        return {}