# Spaces: Runtime error (page-header text captured with the file; not part of the module)
import ntpath
import os
from abc import abstractmethod
from collections import OrderedDict

import torch

from util import util

from . import base_function
class BaseModel(): | |
"""This class is an abstract base class for models""" | |
def __init__(self, opt): | |
"""Initialize the BaseModel class""" | |
self.opt = opt | |
self.gpu_ids = opt.gpu_ids | |
self.isTrain = opt.isTrain | |
self.device = torch.device('cuda') if self.gpu_ids else torch.device('cpu') | |
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) | |
self.loss_names = [] | |
self.model_names = [] | |
self.visual_names = [] | |
self.value_names = [] | |
self.image_paths = [] | |
self.optimizers = [] | |
self.schedulers = [] | |
self.metric = 0 # used for learning rate policy 'plateau' | |
def name(self): | |
return 'BaseModel' | |
def modify_options(parser, is_train): | |
"""Add new options and rewrite default values for existing options""" | |
return parser | |
def set_input(self, input): | |
"""Unpack input data from the dataloader and perform necessary pre-processing steps""" | |
pass | |
def forward(self): | |
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" | |
pass | |
def optimize_parameters(self): | |
"""Calculate losses, gradients, and update network weights; called in every training iteration""" | |
pass | |
def setup(self, opt): | |
"""Load networks, create schedulers""" | |
if self.isTrain: | |
self.schedulers = [base_function.get_scheduler(optimizer, opt) for optimizer in self.optimizers] | |
if not self.isTrain or opt.continue_train: | |
load_suffix = '%d' % opt.which_iter if opt.which_iter > 0 else opt.epoch | |
self.load_networks(load_suffix) | |
self.print_networks() | |
def parallelize(self): | |
for name in self.model_names: | |
if isinstance(name, str): | |
net = getattr(self, 'net' + name) | |
net.to(self.device) | |
if len(self.opt.gpu_ids) > 0: | |
setattr(self, 'net' + name, torch.nn.parallel.DataParallel(net, self.opt.gpu_ids)) | |
def eval(self): | |
"""Make models eval mode during test time""" | |
for name in self.model_names: | |
if isinstance(name, str): | |
net = getattr(self, 'net' + name) | |
net.eval() | |
def log_imgs(self): | |
"""visualize the image during the training""" | |
pass | |
def test(self): | |
"""Forward function used in test time""" | |
with torch.no_grad(): | |
self.forward() | |
def get_image_paths(self): | |
""" Return image paths that are used to load current data""" | |
return self.image_paths | |
def update_learning_rate(self): | |
"""Update learning rates for all the networks; called at the end of every epoch""" | |
for scheduler in self.schedulers: | |
if self.opt.lr_policy == 'plateau': | |
scheduler.step(self.metric) | |
else: | |
scheduler.step() | |
lr = self.optimizers[0].param_groups[0]['lr'] | |
print('learning rate = %.7f' % lr) | |
def get_current_losses(self): | |
"""Return training loss""" | |
errors_ret = OrderedDict() | |
for name in self.loss_names: | |
if isinstance(name, str): | |
try: | |
errors_ret[name] = float(getattr(self, 'loss_' + name)) | |
except: | |
pass | |
return errors_ret | |
def get_current_visuals(self): | |
"""Return visualization examples""" | |
visual_ret = OrderedDict() | |
for name in self.visual_names: | |
if isinstance(name, str): | |
value = getattr(self, name) | |
if isinstance(value, list): | |
visual_ret[name] = value[-1] | |
else: | |
visual_ret[name] = value | |
return visual_ret | |
def save_networks(self, epoch, save_path=None): | |
"""Save all the networks to the disk.""" | |
save_path = save_path if save_path!= None else self.save_dir | |
for name in self.model_names: | |
if isinstance(name, str): | |
filename = '%s_net_%s.pth' % (epoch, name) | |
path = os.path.join(save_path, filename) | |
net = getattr(self, 'net' + name) | |
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): | |
torch.save(net.module.cpu().state_dict(), path) | |
net.cuda(self.gpu_ids[0]) | |
else: | |
torch.save(net.cpu().state_dict(), path) | |
def load_networks(self, epoch, save_path=None): | |
"""Load all the networks from the disk""" | |
save_path = save_path if save_path != None else self.save_dir | |
for name in self.model_names: | |
if isinstance(name, str): | |
filename = '%s_net_%s.pth' % (epoch, name) | |
path = os.path.join(save_path, filename) | |
net = getattr(self, 'net' + name) | |
if isinstance(net, torch.nn.DataParallel): | |
net = net.module | |
print('loading the model from %s' % path) | |
try: | |
state_dict = torch.load(path, map_location=str(self.device)) | |
if hasattr(state_dict, '_metadata'): | |
del state_dict._metadata | |
net.load_state_dict(state_dict) | |
except: | |
print('Pretrained network %s is unmatched' % name) | |
if len(self.gpu_ids) > 0 and torch.cuda.is_available(): | |
net.cuda() | |
def print_networks(self): | |
"""Print the total number of parameters in the network and (if verbose) network architecture""" | |
print('---------- Networks initialized -------------') | |
for name in self.model_names: | |
if isinstance(name, str): | |
net = getattr(self, 'net' + name) | |
num_params = 0 | |
for param in net.parameters(): | |
num_params += param.numel() | |
print(net) | |
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) | |
print('-----------------------------------------------') | |
def set_requires_grad(self, nets, requires_grad=False): | |
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations | |
Parameters: | |
nets (network list) -- a list of networks | |
requires_grad (bool) -- whether the networks require gradients or not | |
""" | |
if not isinstance(nets, list): | |
nets = [nets] | |
for net in nets: | |
if net is not None: | |
for param in net.parameters(): | |
param.requires_grad = requires_grad | |
def save_results(self, save_data, path=None, data_name='none'): | |
"""save the training or testing results to disk""" | |
img_paths = self.get_image_paths() | |
for i in range(save_data.size(0)): | |
short_path = ntpath.basename(img_paths[i]) # get image path | |
name = os.path.splitext(short_path)[0] | |
img_name = '%s_%s.png' % (name, data_name) | |
util.mkdir(path) | |
img_path = os.path.join(path, img_name) | |
img_numpy = util.tensor2im(save_data[i].unsqueeze(0)) | |
util.save_image(img_numpy, img_path) |