Spaces:
Runtime error
Runtime error
from modeling.ifrnet import IFRNet, Discriminator, PatchDiscriminator, MLP | |
from modeling.benchmark import UNet | |
def build_model(args): | |
if args.MODEL.NAME.lower() == "ifrnet": | |
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS) | |
mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, num_class=args.MODEL.NUM_CLASS) | |
elif args.MODEL.NAME.lower() == "ifr-no-aux": | |
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS) | |
mlp = None | |
else: | |
raise NotImplementedError | |
return net, mlp | |
def build_discriminators(args): | |
return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS) | |