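# NOTE: the three lines below appear to be web-page residue captured together
# with this file (avatar caption, commit title, commit hash); kept as comments.
# tfwang's picture
# Update glide_text2im/adv.py
# ccf29e8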
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from .nn import mean_flat
#from . import dist_util
import functools
class AdversarialLoss(nn.Module):
    """Adversarial loss that owns and trains its own discriminator.

    Every ``forward`` call first runs ``gan_k`` optimizer steps on the internal
    discriminator (using a detached copy of ``fake``) and then scores the
    non-detached ``fake`` to produce the generator loss.

    Supported ``gan_type`` values:
        any string containing 'WGAN' -- Wasserstein critic loss; if the string
            also contains 'GP', a gradient penalty with weight 10 is added.
        'GAN' -- binary cross-entropy on the discriminator logits.
    """

    def __init__(self, gan_type='WGAN_GP', gan_k=1,
                 lr_dis=1e-5):
        """Build the DDP-wrapped discriminator and its Adam optimizer.

        Parameters:
            gan_type (str) -- which adversarial objective to use (see class docstring)
            gan_k (int)    -- discriminator updates performed per forward() call
            lr_dis (float) -- learning rate of the discriminator's Adam optimizer
        """
        super(AdversarialLoss, self).__init__()
        self.gan_type = gan_type
        self.gan_k = gan_k
        model = NLayerDiscriminator().cuda()
        self.discriminator = DDP(
            model,
            device_ids=[torch.device('cuda')],
            output_device=torch.device('cuda'),
            broadcast_buffers=False,
            bucket_cap_mb=128,
            find_unused_parameters=False,
        )
        # forward() steps the optimizer for every objective it dispatches on,
        # so build it for all of them (not only 'WGAN_GP' and 'GAN').
        if self._uses_wgan() or self._uses_gan():
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                lr=lr_dis
            )

    def _uses_wgan(self):
        """Return True when gan_type selects a Wasserstein objective."""
        return 'WGAN' in self.gan_type

    def _uses_gan(self):
        """Return True when gan_type selects the cross-entropy GAN objective."""
        return self.gan_type == 'GAN'

    def _gradient_penalty(self, fake_detach, real):
        """Return 10 * mean((||grad_x D(x_hat)||_2 - 1)^2) on random interpolates."""
        # One mixing coefficient per sample, created on the data's own device
        # instead of whatever the current default CUDA device happens to be.
        epsilon = torch.rand(real.size(0), 1, 1, 1, device=real.device)
        epsilon = epsilon.expand(real.size())
        hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
        hat.requires_grad = True
        d_hat = self.discriminator(hat)
        # create_graph=True keeps the penalty differentiable with respect to
        # the discriminator parameters.
        gradients = torch.autograd.grad(
            outputs=d_hat.sum(), inputs=hat,
            retain_graph=True, create_graph=True, only_inputs=True
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        return 10 * gradient_norm.sub(1).pow(2).mean()

    def forward(self, fake, real):
        """Update the discriminator, then return the generator loss.

        Parameters:
            fake -- generator output; detached for the discriminator updates
            real -- reference batch fed to the discriminator

        Returns:
            mean_flat() of the generator loss map (mean_flat comes from .nn;
            presumably it averages over the non-batch dimensions -- confirm).

        Raises:
            ValueError -- if gan_type is neither 'GAN' nor contains 'WGAN'.

        Side effects:
            performs gan_k optimizer steps on the discriminator.
        """
        uses_wgan = self._uses_wgan()
        uses_gan = self._uses_gan()
        # Fail early and clearly instead of hitting an unbound loss variable.
        if not (uses_wgan or uses_gan):
            raise ValueError(
                "unsupported gan_type %r; expected 'GAN' or a type containing 'WGAN'"
                % (self.gan_type,)
            )

        fake_detach = fake.detach()
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if uses_wgan:
                loss_d = (d_fake - d_real).mean()
                if 'GP' in self.gan_type:
                    loss_d = loss_d + self._gradient_penalty(fake_detach, real)
                # NOTE(review): without 'GP' nothing here constrains the critic
                # (no weight clipping is applied) -- confirm this is intended.
            else:
                # Cross-entropy GAN: real logits -> label 1, fake logits -> label 0.
                loss_d = (
                    F.binary_cross_entropy_with_logits(
                        d_real, torch.ones_like(d_real))
                    + F.binary_cross_entropy_with_logits(
                        d_fake, torch.zeros_like(d_fake))
                )
            # Discriminator update
            loss_d.backward()
            self.optimizer.step()

        # Generator loss: score the non-detached fake so gradients reach the generator.
        d_fake_for_g = self.discriminator(fake)
        if uses_wgan:
            loss_g = -d_fake_for_g
        else:
            # Keep the per-element map (reduction='none') so both objectives
            # hand mean_flat() a tensor of the same shape.
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, torch.ones_like(d_fake_for_g), reduction='none')
        return mean_flat(loss_g)
def conv3x3(in_channels, out_channels, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    settings = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_channels, out_channels, **settings)
def conv7x7(in_channels, out_channels, stride=1):
    """Return a biased 7x7 ``nn.Conv2d`` with padding 3 and the given stride."""
    settings = dict(kernel_size=7, stride=stride, padding=3, bias=True)
    return nn.Conv2d(in_channels, out_channels, **settings)
class Discriminator(nn.Module):
    """Twelve-convolution discriminator producing a one-channel map.

    Stages 1-11 are each conv -> InstanceNorm2d(affine=True) -> LeakyReLU(0.2);
    stage 1 uses a 7x7 convolution, the rest use 3x3. Stages 2, 4, 6, 8 and 10
    use stride 2. Stage 12 is a bare 3x3 convolution down to one channel.
    Sub-modules are registered as conv<i>, norm<i> and LReLU<i>.
    """

    def __init__(self, ):
        super().__init__()
        # (in_channels, out_channels, stride) for the eleven normalised stages.
        stage_specs = [
            (3, 32, 1),
            (32, 32, 2),
            (32, 64, 1),
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
            (512, 32, 1),
        ]
        for index, (c_in, c_out, stride) in enumerate(stage_specs, start=1):
            # Only the very first stage uses the wide 7x7 kernel.
            make_conv = conv7x7 if index == 1 else conv3x3
            setattr(self, 'conv%d' % index, make_conv(c_in, c_out, stride))
            setattr(self, 'norm%d' % index, nn.InstanceNorm2d(c_out, affine=True))
            setattr(self, 'LReLU%d' % index, nn.LeakyReLU(0.2))
        # Final projection: no normalisation or activation afterwards.
        self.conv12 = conv3x3(32, 1)

    def forward(self, x):
        """Run x through stages 1-11 in order, then the final convolution."""
        for index in range(1, 12):
            conv = getattr(self, 'conv%d' % index)
            norm = getattr(self, 'norm%d' % index)
            activation = getattr(self, 'LReLU%d' % index)
            x = activation(norm(conv(x)))
        return self.conv12(x)
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a normalization layer.

    The returned callable takes the number of features and returns a module.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For 'batch', the factory is a functools.partial of nn.BatchNorm2d with
    learnable affine parameters and tracked running statistics.
    For 'instance', the factory is a functools.partial of nn.InstanceNorm2d
    with no affine parameters and no running statistics.
    For 'none', the factory ignores its argument and returns nn.Identity().

    Raises:
        NotImplementedError -- if norm_type is not one of the names above.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        def norm_layer(x):
            # Use nn.Identity: no bare `Identity` name is defined or imported
            # in this module, so calling it would raise NameError.
            return nn.Identity()
        return norm_layer
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
class NLayerDiscriminator(nn.Module):
    """PatchGAN-style discriminator built from a stack of 4x4 convolutions."""

    def __init__(self, input_nc=3, ndf=64, n_layers=3 ):
        """Assemble the convolutional stack into ``self.model``.

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- output channels of the first conv layer; later
                              layers use ndf * min(2 ** n, 8)
            n_layers (int) -- controls depth: the first conv plus (n_layers - 1)
                              normalised blocks use stride 2, followed by one
                              stride-1 normalised block and a final 1-channel conv

        The normalization layer is always get_norm_layer(norm_type='instance').
        """
        super().__init__()
        norm_layer = get_norm_layer(norm_type='instance')
        # Convolutions feeding the norm layer carry a bias only when that norm
        # layer is InstanceNorm2d.
        norm_class = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_class == nn.InstanceNorm2d

        kernel = 4
        pad = 1

        # Channel width after the first conv, after each stride-2 block, and
        # after the closing stride-1 block (multiplier capped at 8).
        widths = [ndf]
        widths += [ndf * min(2 ** n, 8) for n in range(1, n_layers)]
        widths.append(ndf * min(2 ** n_layers, 8))
        # Every normalised block downsamples except the last one.
        strides = [2] * (len(widths) - 2) + [1]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out, stride in zip(widths, widths[1:], strides):
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=use_bias)
            )
            layers.append(norm_layer(c_out))
            layers.append(nn.LeakyReLU(0.2, True))
        # Output a 1-channel prediction map.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the sequential stack to ``input`` and return its output."""
        return self.model(input)