File size: 1,081 Bytes
e437acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from .generator import Encoder, Decoder, Encoder_s, Decoder_s, Cls
from .light_cnn import network_29layers_v2, resblock


# define generator
def define_G(hdim=256, attack_type=4):
    netE_nir = Encoder_s(hdim=hdim)
    netCls = Cls(hdim=hdim, attack_type=attack_type)
    netE_vis = Encoder(hdim=hdim)
    netG = Decoder_s(hdim=hdim)

    netE_nir = torch.nn.DataParallel(netE_nir).cuda()
    netE_vis = torch.nn.DataParallel(netE_vis).cuda()
    netG = torch.nn.DataParallel(netG).cuda()
    netCls = torch.nn.DataParallel(netCls).cuda()

    return netE_nir, netE_vis, netG, netCls


# define identity preserving && feature extraction net
def define_IP(is_train=False):
    netIP = network_29layers_v2(resblock, [1, 2, 3, 4], is_train)
    netIP = torch.nn.DataParallel(netIP).cuda()
    return netIP


# define recognition network
def LightCNN_29v2(num_classes=10000, is_train=True):
    net = network_29layers_v2(resblock, [1, 2, 3, 4], is_train, num_classes=num_classes)
    net = torch.nn.DataParallel(net).cuda()
    return net