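# Provenance (from the web file viewer this file was extracted from):
# author "sunshineatnoon", commit 1b2a9b1 ("Add application file"), 4.42 kB.
# Viewer navigation residue: "raw", "history blame", "No virus".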
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import os
import importlib
from swapae.models.base_model import BaseModel
import torch
from torch.nn.parallel import DataParallel
def find_model_using_name(model_name):
    """Import "swapae.models.[model_name]_model" and return its model class.

    The imported module is scanned for an attribute that is a subclass of
    BaseModel and whose name equals ``model_name`` with underscores removed,
    followed by "model". The name comparison is case-insensitive, so
    'dummy' resolves to a class such as ``DummyModel`` in ``dummy_model.py``.

    Args:
        model_name (str): model identifier, e.g. the value of '--model'.

    Returns:
        The matching BaseModel subclass (the class itself, not an instance).

    Raises:
        ImportError: propagated from importlib when the module cannot be
            imported.
        SystemExit: with status 1 when the module contains no matching
            BaseModel subclass.
    """
    model_filename = "swapae.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    target_model_name = model_name.replace('_', '') + 'model'
    wanted = target_model_name.lower()

    model = None
    for name, cls in modellib.__dict__.items():
        # isinstance(cls, type) comes first: issubclass() raises TypeError
        # when handed a non-class object that merely has the right name.
        if (name.lower() == wanted
                and isinstance(cls, type)
                and issubclass(cls, BaseModel)):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # A missing model class is a failure, so terminate with a non-zero
        # status. Raising SystemExit directly also avoids relying on the
        # `exit` helper that only exists when the `site` module is loaded.
        raise SystemExit(1)
    return model
def get_option_setter(model_name):
    """Look up the model class for ``model_name`` and return its
    ``modify_commandline_options`` attribute (not called here)."""
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Build the model selected by ``opt.model`` and wrap it for multi-GPU use.

    Steps, in order: resolve the class with find_model_using_name, construct
    it with ``opt``, call its ``initialize()`` method, wrap the instance in
    MultiGPUModelWrapper, and print a confirmation line naming the class.

    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> model = create_model(opt)

    Returns:
        MultiGPUModelWrapper: the wrapper around the freshly created model.
    """
    model_cls = find_model_using_name(opt.model)

    net = model_cls(opt)
    net.initialize()

    # The wrapper is built before the message is printed, so a failure while
    # wrapping never reports the model as created.
    wrapped = MultiGPUModelWrapper(opt, net)
    print("model [%s] was created" % type(net).__name__)
    return wrapped
class MultiGPUModelWrapper:
    """Holds a model both as a DataParallel wrapper and as the bare module.

    Attributes set in __init__:
        opt: the options object passed in.
        parallelized_model: ``DataParallel`` wrapping the model.
        singlegpu_model: the ``.module`` of that wrapper, i.e. the model itself.
    """

    def __init__(self, opt, model: BaseModel):
        """Wrap ``model`` in DataParallel and run its per-GPU initialization.

        Args:
            opt: options object; ``opt.num_gpus`` decides whether the model is
                moved to 'cuda:0' before wrapping.
            model (BaseModel): the already constructed model.
        """
        self.opt = opt

        use_cuda = opt.num_gpus > 0
        if use_cuda:
            model = model.to("cuda:0")

        data_parallel = DataParallel(model)
        self.parallelized_model = data_parallel
        # The "per_gpu_initialize" command is issued twice on purpose: first
        # through the DataParallel wrapper, then directly on the wrapped
        # module. What the command does is decided by the model's forward().
        data_parallel(command="per_gpu_initialize")

        self.singlegpu_model = data_parallel.module
        self.singlegpu_model(command="per_gpu_initialize")

    def get_parameters_for_mode(self, mode):
        """Delegate to ``get_parameters_for_mode(mode)`` of the bare model."""
        bare_model = self.singlegpu_model
        return bare_model.get_parameters_for_mode(mode)

    def save(self, total_steps_so_far):
        """Delegate to ``save(total_steps_so_far)`` of the bare model."""
        bare_model = self.singlegpu_model
        bare_model.save(total_steps_so_far)

    def __call__(self, *args, **kwargs):
        """Forward the call, unchanged, to the DataParallel-wrapped model.

        Calls reach __call__ of BaseModel through DataParallel, and the method
        named by |command| is then invoked; see BaseModel.forward() for how
        that dispatch is done.
        """
        runner = self.parallelized_model
        return runner(*args, **kwargs)
class StateVariableStorage:
    """Empty attribute container used to hold module-wide state flags."""


# One shared instance for the whole module; the flag starts out disabled.
_state_variables = StateVariableStorage()
setattr(_state_variables, "fix_noise", False)


def fixed_noise():
    """Return the value currently stored in the shared ``fix_noise`` flag."""
    current = _state_variables.fix_noise
    return current


def fix_noise(set=True):
    """Store ``set`` (default True) as the shared ``fix_noise`` flag.

    The value is stored as given, without conversion to bool. The parameter
    name shadows the builtin ``set`` but is kept because callers may pass it
    by keyword.
    """
    setattr(_state_variables, "fix_noise", set)