|
"""This script defines the base network model for Deep3DFaceRecon_pytorch |
|
""" |
|
|
|
import os |
|
import numpy as np |
|
import torch |
|
from collections import OrderedDict |
|
from abc import ABC, abstractmethod |
|
from . import networks |
|
|
|
|
|
class BaseModel(ABC): |
|
"""This class is an abstract base class (ABC) for models. |
|
To create a subclass, you need to implement the following five functions: |
|
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). |
|
-- <set_input>: unpack data from dataset and apply preprocessing. |
|
-- <forward>: produce intermediate results. |
|
-- <optimize_parameters>: calculate losses, gradients, and update network weights. |
|
-- <modify_commandline_options>: (optionally) add model-specific options and set default options. |
|
""" |
|
|
|
def __init__(self, opt): |
|
"""Initialize the BaseModel class. |
|
|
|
Parameters: |
|
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions |
|
|
|
When creating your custom class, you need to implement your own initialization. |
|
In this fucntion, you should first call <BaseModel.__init__(self, opt)> |
|
Then, you need to define four lists: |
|
-- self.loss_names (str list): specify the training losses that you want to plot and save. |
|
-- self.model_names (str list): specify the images that you want to display and save. |
|
-- self.visual_names (str list): define networks used in our training. |
|
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. |
|
""" |
|
self.opt = opt |
|
self.isTrain = False |
|
self.device = torch.device('cpu') |
|
self.save_dir = " " |
|
self.loss_names = [] |
|
self.model_names = [] |
|
self.visual_names = [] |
|
self.parallel_names = [] |
|
self.optimizers = [] |
|
self.image_paths = [] |
|
self.metric = 0 |
|
|
|
@staticmethod |
|
def dict_grad_hook_factory(add_func=lambda x: x): |
|
saved_dict = dict() |
|
|
|
def hook_gen(name): |
|
def grad_hook(grad): |
|
saved_vals = add_func(grad) |
|
saved_dict[name] = saved_vals |
|
return grad_hook |
|
return hook_gen, saved_dict |
|
|
|
@staticmethod |
|
def modify_commandline_options(parser, is_train): |
|
"""Add new model-specific options, and rewrite default values for existing options. |
|
|
|
Parameters: |
|
parser -- original option parser |
|
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. |
|
|
|
Returns: |
|
the modified parser. |
|
""" |
|
return parser |
|
|
|
@abstractmethod |
|
def set_input(self, input): |
|
"""Unpack input data from the dataloader and perform necessary pre-processing steps. |
|
|
|
Parameters: |
|
input (dict): includes the data itself and its metadata information. |
|
""" |
|
pass |
|
|
|
@abstractmethod |
|
def forward(self): |
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" |
|
pass |
|
|
|
@abstractmethod |
|
def optimize_parameters(self): |
|
"""Calculate losses, gradients, and update network weights; called in every training iteration""" |
|
pass |
|
|
|
def setup(self, opt): |
|
"""Load and print networks; create schedulers |
|
|
|
Parameters: |
|
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions |
|
""" |
|
if self.isTrain: |
|
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] |
|
|
|
if not self.isTrain or opt.continue_train: |
|
load_suffix = opt.epoch |
|
self.load_networks(load_suffix) |
|
|
|
|
|
|
|
|
|
def parallelize(self, convert_sync_batchnorm=True): |
|
if not self.opt.use_ddp: |
|
for name in self.parallel_names: |
|
if isinstance(name, str): |
|
module = getattr(self, name) |
|
setattr(self, name, module.to(self.device)) |
|
else: |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
module = getattr(self, name) |
|
if convert_sync_batchnorm: |
|
module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) |
|
setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), |
|
device_ids=[self.device.index], |
|
find_unused_parameters=True, broadcast_buffers=True)) |
|
|
|
|
|
for name in self.parallel_names: |
|
if isinstance(name, str) and name not in self.model_names: |
|
module = getattr(self, name) |
|
setattr(self, name, module.to(self.device)) |
|
|
|
|
|
if self.opt.phase != 'test': |
|
if self.opt.continue_train: |
|
for optim in self.optimizers: |
|
for state in optim.state.values(): |
|
for k, v in state.items(): |
|
if isinstance(v, torch.Tensor): |
|
state[k] = v.to(self.device) |
|
|
|
def data_dependent_initialize(self, data): |
|
pass |
|
|
|
def train(self): |
|
"""Make models train mode""" |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
net.train() |
|
|
|
def eval(self): |
|
"""Make models eval mode""" |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
net.eval() |
|
|
|
def test(self): |
|
"""Forward function used in test time. |
|
|
|
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop |
|
It also calls <compute_visuals> to produce additional visualization results |
|
""" |
|
with torch.no_grad(): |
|
self.forward() |
|
self.compute_visuals() |
|
|
|
def compute_visuals(self): |
|
"""Calculate additional output images for visdom and HTML visualization""" |
|
pass |
|
|
|
def get_image_paths(self, name='A'): |
|
""" Return image paths that are used to load current data""" |
|
return self.image_paths if name =='A' else self.image_paths_B |
|
|
|
def update_learning_rate(self): |
|
"""Update learning rates for all the networks; called at the end of every epoch""" |
|
for scheduler in self.schedulers: |
|
if self.opt.lr_policy == 'plateau': |
|
scheduler.step(self.metric) |
|
else: |
|
scheduler.step() |
|
|
|
lr = self.optimizers[0].param_groups[0]['lr'] |
|
print('learning rate = %.7f' % lr) |
|
|
|
def get_current_visuals(self): |
|
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" |
|
visual_ret = OrderedDict() |
|
for name in self.visual_names: |
|
if isinstance(name, str): |
|
visual_ret[name] = getattr(self, name)[:, :3, ...] |
|
return visual_ret |
|
|
|
def get_current_losses(self): |
|
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" |
|
errors_ret = OrderedDict() |
|
for name in self.loss_names: |
|
if isinstance(name, str): |
|
errors_ret[name] = float(getattr(self, 'loss_' + name)) |
|
return errors_ret |
|
|
|
def save_networks(self, epoch): |
|
"""Save all the networks to the disk. |
|
|
|
Parameters: |
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
|
""" |
|
if not os.path.isdir(self.save_dir): |
|
os.makedirs(self.save_dir) |
|
|
|
save_filename = 'epoch_%s.pth' % (epoch) |
|
save_path = os.path.join(self.save_dir, save_filename) |
|
|
|
save_dict = {} |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
if isinstance(net, torch.nn.DataParallel) or isinstance(net, |
|
torch.nn.parallel.DistributedDataParallel): |
|
net = net.module |
|
save_dict[name] = net.state_dict() |
|
|
|
|
|
for i, optim in enumerate(self.optimizers): |
|
save_dict['opt_%02d'%i] = optim.state_dict() |
|
|
|
for i, sched in enumerate(self.schedulers): |
|
save_dict['sched_%02d'%i] = sched.state_dict() |
|
|
|
torch.save(save_dict, save_path) |
|
|
|
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): |
|
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" |
|
key = keys[i] |
|
if i + 1 == len(keys): |
|
if module.__class__.__name__.startswith('InstanceNorm') and \ |
|
(key == 'running_mean' or key == 'running_var'): |
|
if getattr(module, key) is None: |
|
state_dict.pop('.'.join(keys)) |
|
if module.__class__.__name__.startswith('InstanceNorm') and \ |
|
(key == 'num_batches_tracked'): |
|
state_dict.pop('.'.join(keys)) |
|
else: |
|
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) |
|
|
|
def load_networks(self, epoch): |
|
"""Load all the networks from the disk. |
|
|
|
Parameters: |
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
|
""" |
|
if self.opt.isTrain and self.opt.pretrained_name is not None: |
|
load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) |
|
else: |
|
load_dir = self.save_dir |
|
load_filename = 'epoch_%s.pth' % (epoch) |
|
load_path = os.path.join(load_dir, load_filename) |
|
state_dict = torch.load(load_path, map_location=self.device) |
|
print('loading the model from %s' % load_path) |
|
|
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
net.load_state_dict(state_dict[name]) |
|
|
|
if self.opt.phase != 'test': |
|
if self.opt.continue_train: |
|
print('loading the optim from %s' % load_path) |
|
for i, optim in enumerate(self.optimizers): |
|
optim.load_state_dict(state_dict['opt_%02d'%i]) |
|
|
|
try: |
|
print('loading the sched from %s' % load_path) |
|
for i, sched in enumerate(self.schedulers): |
|
sched.load_state_dict(state_dict['sched_%02d'%i]) |
|
except: |
|
print('Failed to load schedulers, set schedulers according to epoch count manually') |
|
for i, sched in enumerate(self.schedulers): |
|
sched.last_epoch = self.opt.epoch_count - 1 |
|
|
|
|
|
|
|
|
|
def print_networks(self, verbose): |
|
"""Print the total number of parameters in the network and (if verbose) network architecture |
|
|
|
Parameters: |
|
verbose (bool) -- if verbose: print the network architecture |
|
""" |
|
print('---------- Networks initialized -------------') |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
num_params = 0 |
|
for param in net.parameters(): |
|
num_params += param.numel() |
|
if verbose: |
|
print(net) |
|
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) |
|
print('-----------------------------------------------') |
|
|
|
def set_requires_grad(self, nets, requires_grad=False): |
|
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations |
|
Parameters: |
|
nets (network list) -- a list of networks |
|
requires_grad (bool) -- whether the networks require gradients or not |
|
""" |
|
if not isinstance(nets, list): |
|
nets = [nets] |
|
for net in nets: |
|
if net is not None: |
|
for param in net.parameters(): |
|
param.requires_grad = requires_grad |
|
|
|
def generate_visuals_for_evaluation(self, data, mode): |
|
return {} |
|
|