# Spaces: Runtime error
import torch | |
import swapae.util as util | |
from swapae.models.networks.base_network import BaseNetwork | |
def find_network_using_name(target_network_name, filename):
    """Resolve the network class selected by ``target_network_name``.

    The class searched for is named ``target_network_name + filename`` and
    lives in the module ``swapae.models.networks.<filename>``. The actual
    lookup is delegated to ``util.find_class_in_module``; what that helper
    does when no class matches is not visible here (NOTE(review): confirm).

    Args:
        target_network_name: prefix of the class name (e.g. the value of an
            ``--netE`` / ``--netG`` style option).
        filename: module name under ``swapae.models.networks``; also used as
            the class-name suffix.

    Returns:
        The resolved class.

    Raises:
        AssertionError: if the resolved class is not a ``BaseNetwork``
            subclass. ``issubclass`` itself raises ``TypeError`` when the
            resolved object is not a class at all.
    """
    class_name = target_network_name + filename
    module_path = 'swapae.models.networks.' + filename
    net_class = util.find_class_in_module(class_name, module_path)
    # Every factory-built network must share the BaseNetwork interface
    # (callers rely on e.g. modify_commandline_options / print_architecture).
    assert issubclass(net_class, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % net_class
    return net_class
def modify_commandline_options(parser, is_train):
    """Let every selected network class add its own command-line options.

    The options already known to ``parser`` are parsed once, up front, with
    ``parse_known_args`` so that the selected network names (``netE``,
    ``netG``, ``netD`` and the optional ``netPatchD``) can be read. Each
    selected class then receives the parser in turn and returns the
    (possibly updated) parser used for the next one.

    Args:
        parser: the argument parser; it must expose ``parse_known_args``.
        is_train: forwarded unchanged to each network class's
            ``modify_commandline_options``.

    Returns:
        The parser returned by the last network class that was consulted.
    """
    opt, _ = parser.parse_known_args()

    # (selected class-name prefix, module / class-name suffix), in the same
    # order the options were registered originally.
    selections = [
        (opt.netE, 'encoder'),
        (opt.netG, 'generator'),
        (opt.netD, 'discriminator'),
    ]
    # The patch discriminator is optional: only consult it when one is named.
    if opt.netPatchD is not None:
        selections.append((opt.netPatchD, 'patch_discriminator'))

    for network_name, filename in selections:
        net_cls = find_network_using_name(network_name, filename)
        # Applied uniformly; the original skipped this check for netD only.
        assert net_cls is not None
        parser = net_cls.modify_commandline_options(parser, is_train)
    return parser
def create_network(opt, network_name, mode, verbose=True):
    """Build the network named ``network_name`` for the given ``mode``.

    ``mode`` is passed to ``find_network_using_name`` as its ``filename``
    argument, i.e. the module / class-name suffix (for example
    ``'encoder'`` or ``'generator'``).

    Args:
        opt: options object handed straight to the network's constructor.
        network_name: class-name prefix of the network, or ``None`` to
            request no network.
        mode: module / class-name suffix used for the lookup.
        verbose: when true, print the network's architecture after
            construction.

    Returns:
        The constructed network, or ``None`` if ``network_name`` is ``None``.
    """
    if network_name is not None:
        network = find_network_using_name(network_name, mode)(opt)
        if verbose:
            network.print_architecture(verbose=True)
        return network
    return None