import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
| |
|
class Generator(nn.Module):
    """Convolutional encoder-decoder mapping an image batch to a same-sized batch.

    Two stride-2 convolutions halve the spatial size twice, then two stride-2
    transposed convolutions double it twice, so the output shape equals the
    input shape whenever height and width are multiples of four. The final
    ``tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        in_ch: Number of channels of the input (and therefore output) images.
    """

    def __init__(self, in_ch):
        super(Generator, self).__init__()
        # Encoder: in_ch -> 64 -> 128 channels, spatial size / 4.
        self.conv1 = nn.Conv2d(in_ch, 64, 4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Decoder: 128 -> 64 -> in_ch channels, spatial size * 4.
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, in_ch, 4, stride=2, padding=1)

    def forward(self, x):
        """Return the reconstructed batch for ``x`` of shape ``(N, in_ch, H, W)``."""
        h = F.leaky_relu(self.bn1(self.conv1(x)))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.deconv3(h)))
        # torch.tanh replaces the deprecated torch.nn.functional.tanh alias.
        return torch.tanh(self.deconv4(h))
| |
|
class Discriminator(nn.Module):
    """Convolutional binary classifier producing one probability per image.

    Three unpadded 3x3 stride-2 convolutions reduce the input to a 256-channel
    feature map that is flattened and fed to a single linear unit followed by
    a sigmoid.

    The width of the linear layer is hard-coded: 1024 features when
    ``in_ch == 1`` (a 2x2 final map, i.e. 28x28 inputs) and 2304 features
    otherwise (a 3x3 final map, i.e. 32x32 inputs). Other input sizes will fail
    in ``fc4`` with a shape mismatch.
    NOTE(review): presumably these correspond to MNIST and CIFAR-style data;
    confirm against the dataset actually used.

    Args:
        in_ch: Number of channels of the input images.
    """

    def __init__(self, in_ch):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, 3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        # 256 * 2 * 2 for 28x28 inputs, 256 * 3 * 3 for 32x32 inputs.
        if in_ch == 1:
            self.fc4 = nn.Linear(1024, 1)
        else:
            self.fc4 = nn.Linear(2304, 1)

    def forward(self, x):
        """Return a ``(N, 1)`` tensor of probabilities in ``[0, 1]`` for batch ``x``."""
        h = F.leaky_relu(self.conv1(x))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.conv3(h)))
        # Flatten everything but the batch axis before the linear layer.
        flat = h.view(h.size(0), -1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid alias.
        return torch.sigmoid(self.fc4(flat))
| |
|
| |
|
def main(args):
    """Train the generator/discriminator pair on paired clean/adversarial images.

    Loads ``./adv_data.tar``, a dict holding a ``"normal"`` tensor and an
    ``"adv"`` tensor of matching length, and trains for ``args.epochs`` epochs.
    Per batch, the discriminator is updated once (real = clean images,
    fake = generator output on adversarial images) and the generator is
    updated twice with ``alpha * MSE(G(x_adv), x) + beta * BCE(D(G(x_adv)), 1)``.
    After every epoch the averaged losses are printed and both state dicts are
    written to ``<args.checkpoint>/<epoch + 1>.tar``.

    Args:
        args: Parsed command-line namespace. Reads ``lr``, ``epochs``,
            ``alpha``, ``beta``, ``checkpoint`` and, when present,
            ``batch_size`` (a fallback of 128 is used otherwise).
    """
    # Use the GPU when one is available instead of failing outright on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the archive; only load trusted files.
    train_data = torch.load("./adv_data.tar")
    # The stored tensors are fed straight into Conv2d layers, so axis 1 is the
    # channel axis; use it rather than an undefined global for the model width.
    in_ch = train_data["normal"].size(1)
    train_data = TensorDataset(train_data["normal"], train_data["adv"])
    batch_size = getattr(args, "batch_size", 128)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    G = Generator(in_ch=in_ch).to(device)
    D = Discriminator(in_ch=in_ch).to(device)

    opt_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))
    loss_bce = nn.BCELoss()
    loss_mse = nn.MSELoss()
    cudnn.benchmark = True

    # torch.save does not create missing parent directories.
    os.makedirs(args.checkpoint, exist_ok=True)

    for i in range(args.epochs):
        G.train()
        gen_loss, dis_loss, n = 0.0, 0.0, 0
        for x, x_adv in train_loader:
            current_size = x.size(0)
            x, x_adv = x.to(device), x_adv.to(device)

            t_real = torch.ones(current_size, device=device)
            t_fake = torch.zeros(current_size, device=device)

            # --- Discriminator step ---
            # view(-1) keeps a 1-D shape even for a final batch of size one,
            # where squeeze() would yield a 0-dim tensor that mismatches the target.
            y_real = D(x).view(-1)
            # detach(): the discriminator update must not backpropagate into G.
            y_fake = D(G(x_adv).detach()).view(-1)

            loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake)
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # --- Generator step (two updates per discriminator update) ---
            for _ in range(2):
                x_fake = G(x_adv)
                y_fake = D(x_fake).view(-1)

                loss_G = args.alpha * loss_mse(x_fake, x) + args.beta * loss_bce(y_fake, t_real)
                opt_G.zero_grad()
                loss_G.backward()
                opt_G.step()

            # Accumulate each loss under its own name (they were previously
            # swapped); .item() replaces .data[0], which fails on 0-dim tensors.
            gen_loss += loss_G.item() * current_size
            dis_loss += loss_D.item() * current_size
            n += current_size

        # Guard against an empty dataset so the report never divides by zero.
        denom = max(n, 1)
        print("epoch:{}, LossG:{:.3f}, LossD:{:.3f}".format(i, gen_loss / denom, dis_loss / denom))
        torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()},
                   os.path.join(args.checkpoint, "{}.tar".format(i + 1)))

    G.eval()
| |
|
def get_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the attributes ``data``, ``lr``, ``epochs``,
        ``batch_size``, ``alpha``, ``beta`` and ``checkpoint``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="mnist")
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--epochs", type=int, default=2)
    # The training loop reads args.batch_size, so the parser has to define it.
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--alpha", type=float, default=0.7)
    parser.add_argument("--beta", type=float, default=0.3)
    parser.add_argument("--checkpoint", type=str, default="./checkpoint/test")
    return parser.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    # get_args() returns the parsed namespace; it must be passed on to main()
    # (the result was previously discarded, leaving `args` undefined).
    main(get_args())
| |
|