Spaces:
Runtime error
Runtime error
File size: 1,533 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import swapae.util as util
from swapae.models.networks.base_network import BaseNetwork
def find_network_using_name(target_network_name, filename):
target_class_name = target_network_name + filename
module_name = 'swapae.models.networks.' + filename
network = util.find_class_in_module(target_class_name, module_name)
assert issubclass(network, BaseNetwork), \
"Class %s should be a subclass of BaseNetwork" % network
return network
def modify_commandline_options(parser, is_train):
opt, _ = parser.parse_known_args()
netE_cls = find_network_using_name(opt.netE, 'encoder')
assert netE_cls is not None
parser = netE_cls.modify_commandline_options(parser, is_train)
netG_cls = find_network_using_name(opt.netG, 'generator')
assert netG_cls is not None
parser = netG_cls.modify_commandline_options(parser, is_train)
netD_cls = find_network_using_name(opt.netD, 'discriminator')
parser = netD_cls.modify_commandline_options(parser, is_train)
if opt.netPatchD is not None:
netD_cls = find_network_using_name(opt.netPatchD, 'patch_discriminator')
assert netD_cls is not None
parser = netD_cls.modify_commandline_options(parser, is_train)
return parser
def create_network(opt, network_name, mode, verbose=True):
if network_name is None:
return None
net_cls = find_network_using_name(network_name, mode)
net = net_cls(opt)
if verbose:
net.print_architecture(verbose=True)
return net
|