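# File-viewer header captured with this source (kept as a comment so the
# module parses): sunshineatnoon, commit 1b2a9b1 ("Add application file"),
# 1.53 kB. Viewer labels: raw | history blame | No virus.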
import torch
import swapae.util as util
from swapae.models.networks.base_network import BaseNetwork
def find_network_using_name(target_network_name, filename):
    """Return the network class ``target_network_name + filename``.

    The class is searched for in the module
    ``swapae.models.networks.<filename>``. The actual name matching is
    delegated to ``util.find_class_in_module``.

    Args:
        target_network_name: network name prefix, e.g. the value of
            ``opt.netE`` / ``opt.netG`` / ``opt.netD``.
        filename: module name under ``swapae.models.networks``. It is also
            appended to ``target_network_name`` to form the class name.

    Returns:
        The class found by ``util.find_class_in_module``.

    Raises:
        AssertionError: if the lookup result is not a subclass of
            ``BaseNetwork``, including when it is not a class at all.
    """
    target_class_name = target_network_name + filename
    module_name = 'swapae.models.networks.' + filename
    network = util.find_class_in_module(target_class_name, module_name)
    # issubclass() raises TypeError when its first argument is not a class
    # (e.g. None), which would hide the message below. Check for a class
    # first so every bad lookup result produces this assertion message.
    assert isinstance(network, type) and issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network
    return network
def _apply_network_options(parser, is_train, network_name, filename):
    """Look up one network class and let its hook extend ``parser``."""
    net_cls = find_network_using_name(network_name, filename)
    assert net_cls is not None
    return net_cls.modify_commandline_options(parser, is_train)


def modify_commandline_options(parser, is_train):
    """Let every selected network register its own command-line options.

    The current option values are read once with ``parse_known_args``,
    presumably so that options not registered yet do not cause an error.
    Each network class is then looked up and its
    ``modify_commandline_options`` hook is applied to ``parser``, in this
    order:
    encoder (``opt.netE``), generator (``opt.netG``), discriminator
    (``opt.netD``), and the patch discriminator (``opt.netPatchD``) only
    when that option is not ``None``.

    Args:
        parser: argument parser to extend. It must support
            ``parse_known_args()``.
        is_train: forwarded unchanged to every network hook.

    Returns:
        The parser returned by the last hook applied.
    """
    opt, _ = parser.parse_known_args()
    # Every lookup goes through the same helper, so the None check is
    # applied uniformly, including to the discriminator.
    parser = _apply_network_options(parser, is_train, opt.netE, 'encoder')
    parser = _apply_network_options(parser, is_train, opt.netG, 'generator')
    parser = _apply_network_options(parser, is_train, opt.netD, 'discriminator')
    if opt.netPatchD is not None:
        parser = _apply_network_options(
            parser, is_train, opt.netPatchD, 'patch_discriminator')
    return parser
def create_network(opt, network_name, mode, verbose=True):
    """Build the network called ``network_name`` from module ``mode``.

    Args:
        opt: options object passed as the sole constructor argument.
        network_name: network name prefix. ``None`` means no network is
            requested.
        mode: module name under ``swapae.models.networks``, e.g.
            ``'encoder'`` or ``'generator'``.
        verbose: when true, print the architecture after construction.

    Returns:
        The constructed network, or ``None`` if ``network_name`` is ``None``.
    """
    if network_name is not None:
        instance = find_network_using_name(network_name, mode)(opt)
        if verbose:
            instance.print_architecture(verbose=True)
        return instance
    return None