import os
import torch
from collections import OrderedDict
from . import networks


class BaseModel():
    """Common base for models: option storage, LR schedulers and checkpoint I/O.

    Subclasses are expected to fill ``self.model_names``, ``self.loss_names``
    and ``self.visual_names``, and to expose every network as an attribute
    named ``'net' + name`` and every loss as ``'loss_' + name``. In training
    mode they must also define ``self.optimizers`` (not initialized here)
    before ``setup()`` is called.
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for model-specific command line options; base returns ``parser`` unchanged."""
        return parser

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store options and derive device, checkpoint directory and bookkeeping lists."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cudnn autotuning is enabled except for 'scale_width' preprocessing
        # (presumably because that mode can produce varying input sizes).
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []
        # NOTE: self.optimizers is deliberately not assigned here; a training
        # subclass must provide it (setup / update_learning_rate read it).

    def set_input(self, input):
        """Store the raw input batch; subclasses usually override this."""
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        """Create schedulers (training), load checkpoints (test or continue_train), report nets.

        ``parser`` is accepted for interface compatibility and is not used.
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        """Switch every network listed in ``model_names`` to eval mode."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        """Run ``forward()`` with autograd disabled."""
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every scheduler and print the first optimizer's current learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to an html page
    def get_current_visuals(self):
        """Return an OrderedDict mapping each visual name to the attribute of that name."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Return an OrderedDict mapping each loss name to ``float(self.loss_<name>)``."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    @staticmethod
    def _unwrap_parallel(net):
        """Return ``net.module`` for a DataParallel wrapper, otherwise ``net`` itself."""
        if isinstance(net, torch.nn.DataParallel):
            return net.module
        return net

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save each network's weights to ``<save_dir>/<which_epoch>_net_<name>.pth``.

        The bare (DataParallel-unwrapped) module is saved so the checkpoint keys
        never carry a ``module.`` prefix, matching how ``load_networks`` unwraps
        the net before comparing key names.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (which_epoch, name)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, 'net' + name)
            # Module.cpu() moves parameters in place, so the net is moved
            # back to the first GPU afterwards when running on GPU.
            torch.save(self._unwrap_parallel(net).cpu().state_dict(), save_path)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Recursively follow ``keys`` through ``module`` and drop stale InstanceNorm entries.

        At the leaf, an InstanceNorm ``running_mean``/``running_var`` entry is
        removed when the module's attribute is None, and an InstanceNorm
        ``num_batches_tracked`` entry is always removed. Not called by
        ``load_networks`` at present.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load ``<save_dir>/<which_epoch>_net_<name>.pth`` into each network (partial load).

        Missing checkpoint files are skipped silently. Only checkpoint entries
        whose names exist in the current model are copied; all other model
        entries keep their current values.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (which_epoch, name)
            load_path = os.path.join(self.save_dir, load_filename)
            net = self._unwrap_parallel(getattr(self, 'net' + name))
            if not os.path.exists(load_path):
                continue
            # str() on the device is only needed for PyTorch 0.4; newer
            # versions accept a torch.device as map_location directly.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(state_dict, '_metadata'):
                del state_dict._metadata
            # NOTE: to patch InstanceNorm checkpoints saved before PyTorch 0.4,
            # call self.__patch_instance_norm_state_dict(state_dict, net,
            # key.split('.')) for each key in list(state_dict.keys()).
            model_dict = net.state_dict()
            model_dict.update({k: v for k, v in state_dict.items() if k in model_dict})
            net.load_state_dict(model_dict)

    # print network information
    def print_networks(self, verbose):
        """Count each network's parameters; reporting is currently disabled.

        This method prints nothing and ``verbose`` is unused. Re-enable by
        printing ``net`` when ``verbose`` is true and
        ``'[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)``.
        """
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = sum(param.numel() for param in net.parameters())

    # set requires_grad=False to avoid unnecessary gradient computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of ``nets`` (a net or list; None entries skipped)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad