Text2Human / Text2Human /models /vqgan_model.py
yumingj's picture
update
bde71cb
raw
history blame
No virus
20 kB
import math
import sys
from collections import OrderedDict
sys.path.append('..')
import lpips
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
VectorQuantizer, VectorQuantizerTexture)
from models.losses.segmentation_loss import BCELossWithQuant
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
calculate_adaptive_weight, hinge_d_loss)
class VQModel():
def __init__(self, opt):
super().__init__()
self.opt = opt
self.device = torch.device('cuda')
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],
attn_resolutions=opt['attn_resolutions'],
ch_mult=opt['ch_mult'],
in_channels=opt['in_channels'],
resolution=opt['resolution'],
z_channels=opt['z_channels'],
double_z=opt['double_z'],
dropout=opt['dropout']).to(self.device)
self.decoder = Decoder(
in_channels=opt['in_channels'],
resolution=opt['resolution'],
z_channels=opt['z_channels'],
ch=opt['ch'],
out_ch=opt['out_ch'],
num_res_blocks=opt['num_res_blocks'],
attn_resolutions=opt['attn_resolutions'],
ch_mult=opt['ch_mult'],
dropout=opt['dropout'],
resamp_with_conv=True,
give_pre_end=False).to(self.device)
self.quantize = VectorQuantizer(
opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
1).to(self.device)
self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
opt["z_channels"],
1).to(self.device)
def init_training_settings(self):
self.loss = BCELossWithQuant()
self.log_dict = OrderedDict()
self.configure_optimizers()
def save_network(self, save_path):
"""Save networks.
Args:
net (nn.Module): Network to be saved.
net_label (str): Network label.
current_iter (int): Current iter number.
"""
save_dict = {}
save_dict['encoder'] = self.encoder.state_dict()
save_dict['decoder'] = self.decoder.state_dict()
save_dict['quantize'] = self.quantize.state_dict()
save_dict['quant_conv'] = self.quant_conv.state_dict()
save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
save_dict['discriminator'] = self.disc.state_dict()
torch.save(save_dict, save_path)
def load_network(self):
checkpoint = torch.load(self.opt['pretrained_models'])
self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
self.post_quant_conv.load_state_dict(
checkpoint['post_quant_conv'], strict=True)
def optimize_parameters(self, data, current_iter):
self.encoder.train()
self.decoder.train()
self.quantize.train()
self.quant_conv.train()
self.post_quant_conv.train()
loss = self.training_step(data)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward_step(self, input):
quant, diff, _ = self.encode(input)
dec = self.decode(quant)
return dec, diff
def feed_data(self, data):
x = data['segm']
x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
return x.float().to(self.device)
def get_current_log(self):
return self.log_dict
def update_learning_rate(self, epoch):
"""Update learning rate.
Args:
current_iter (int): Current iteration.
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
Default: -1.
"""
lr = self.optimizer.param_groups[0]['lr']
if self.opt['lr_decay'] == 'step':
lr = self.opt['lr'] * (
self.opt['gamma']**(epoch // self.opt['step']))
elif self.opt['lr_decay'] == 'cos':
lr = self.opt['lr'] * (
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
elif self.opt['lr_decay'] == 'linear':
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
elif self.opt['lr_decay'] == 'linear2exp':
if epoch < self.opt['turning_point'] + 1:
# learning rate decay as 95%
# at the turning point (1 / 95% = 1.0526)
lr = self.opt['lr'] * (
1 - epoch / int(self.opt['turning_point'] * 1.0526))
else:
lr *= self.opt['gamma']
elif self.opt['lr_decay'] == 'schedule':
if epoch in self.opt['schedule']:
lr *= self.opt['gamma']
else:
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
# set learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
class VQSegmentationModel(VQModel):
def __init__(self, opt):
super().__init__(opt)
self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
1).to(self.device)
self.init_training_settings()
def configure_optimizers(self):
self.optimizer = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters()) +
list(self.quantize.parameters()) +
list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=self.opt['lr'],
betas=(0.5, 0.9))
def training_step(self, data):
x = self.feed_data(data)
xrec, qloss = self.forward_step(x)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
self.log_dict.update(log_dict_ae)
return aeloss
def to_rgb(self, x):
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
@torch.no_grad()
def inference(self, data_loader, save_dir):
self.encoder.eval()
self.decoder.eval()
self.quantize.eval()
self.quant_conv.eval()
self.post_quant_conv.eval()
loss_total = 0
loss_bce = 0
loss_quant = 0
num = 0
for _, data in enumerate(data_loader):
img_name = data['img_name'][0]
x = self.feed_data(data)
xrec, qloss = self.forward_step(x)
_, log_dict_ae = self.loss(qloss, x, xrec, split="val")
loss_total += log_dict_ae['val/total_loss']
loss_bce += log_dict_ae['val/bce_loss']
loss_quant += log_dict_ae['val/quant_loss']
num += x.size(0)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
# convert logits to indices
xrec = torch.argmax(xrec, dim=1, keepdim=True)
xrec = F.one_hot(xrec, num_classes=x.shape[1])
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
img_cat = torch.cat([x, xrec], dim=3).detach()
img_cat = ((img_cat + 1) / 2)
img_cat = img_cat.clamp_(0, 1)
save_image(
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
return (loss_total / num).item(), (loss_bce /
num).item(), (loss_quant /
num).item()
class VQImageModel(VQModel):
def __init__(self, opt):
super().__init__(opt)
self.disc = Discriminator(
opt['n_channels'], opt['ndf'],
n_layers=opt['disc_layers']).to(self.device)
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
self.perceptual_weight = opt['perceptual_weight']
self.disc_start_step = opt['disc_start_step']
self.disc_weight_max = opt['disc_weight_max']
self.diff_aug = opt['diff_aug']
self.policy = "color,translation"
self.disc.train()
self.init_training_settings()
def feed_data(self, data):
x = data['image']
return x.float().to(self.device)
def init_training_settings(self):
self.log_dict = OrderedDict()
self.configure_optimizers()
def configure_optimizers(self):
self.optimizer = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters()) +
list(self.quantize.parameters()) +
list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=self.opt['lr'])
self.disc_optimizer = torch.optim.Adam(
self.disc.parameters(), lr=self.opt['lr'])
def training_step(self, data, step):
x = self.feed_data(data)
xrec, codebook_loss = self.forward_step(x)
# get recon/perceptual loss
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
nll_loss = recon_loss + self.perceptual_weight * p_loss
nll_loss = torch.mean(nll_loss)
# augment for input to discriminator
if self.diff_aug:
xrec = DiffAugment(xrec, policy=self.policy)
# update generator
logits_fake = self.disc(xrec)
g_loss = -torch.mean(logits_fake)
last_layer = self.decoder.conv_out.weight
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
self.disc_weight_max)
d_weight *= adopt_weight(1, step, self.disc_start_step)
loss = nll_loss + d_weight * g_loss + codebook_loss
self.log_dict["loss"] = loss
self.log_dict["l1"] = recon_loss.mean().item()
self.log_dict["perceptual"] = p_loss.mean().item()
self.log_dict["nll_loss"] = nll_loss.item()
self.log_dict["g_loss"] = g_loss.item()
self.log_dict["d_weight"] = d_weight
self.log_dict["codebook_loss"] = codebook_loss.item()
if step > self.disc_start_step:
if self.diff_aug:
logits_real = self.disc(
DiffAugment(x.contiguous().detach(), policy=self.policy))
else:
logits_real = self.disc(x.contiguous().detach())
logits_fake = self.disc(xrec.contiguous().detach(
)) # detach so that generator isn"t also updated
d_loss = hinge_d_loss(logits_real, logits_fake)
self.log_dict["d_loss"] = d_loss
else:
d_loss = None
return loss, d_loss
def optimize_parameters(self, data, step):
self.encoder.train()
self.decoder.train()
self.quantize.train()
self.quant_conv.train()
self.post_quant_conv.train()
loss, d_loss = self.training_step(data, step)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if step > self.disc_start_step:
self.disc_optimizer.zero_grad()
d_loss.backward()
self.disc_optimizer.step()
@torch.no_grad()
def inference(self, data_loader, save_dir):
self.encoder.eval()
self.decoder.eval()
self.quantize.eval()
self.quant_conv.eval()
self.post_quant_conv.eval()
loss_total = 0
num = 0
for _, data in enumerate(data_loader):
img_name = data['img_name'][0]
x = self.feed_data(data)
xrec, _ = self.forward_step(x)
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
nll_loss = recon_loss + self.perceptual_weight * p_loss
nll_loss = torch.mean(nll_loss)
loss_total += nll_loss
num += x.size(0)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
# convert logits to indices
xrec = torch.argmax(xrec, dim=1, keepdim=True)
xrec = F.one_hot(xrec, num_classes=x.shape[1])
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
img_cat = torch.cat([x, xrec], dim=3).detach()
img_cat = ((img_cat + 1) / 2)
img_cat = img_cat.clamp_(0, 1)
save_image(
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
return (loss_total / num).item()
class VQImageSegmTextureModel(VQImageModel):
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda')
self.encoder = Encoder(
ch=opt['ch'],
num_res_blocks=opt['num_res_blocks'],
attn_resolutions=opt['attn_resolutions'],
ch_mult=opt['ch_mult'],
in_channels=opt['in_channels'],
resolution=opt['resolution'],
z_channels=opt['z_channels'],
double_z=opt['double_z'],
dropout=opt['dropout']).to(self.device)
self.decoder = Decoder(
in_channels=opt['in_channels'],
resolution=opt['resolution'],
z_channels=opt['z_channels'],
ch=opt['ch'],
out_ch=opt['out_ch'],
num_res_blocks=opt['num_res_blocks'],
attn_resolutions=opt['attn_resolutions'],
ch_mult=opt['ch_mult'],
dropout=opt['dropout'],
resamp_with_conv=True,
give_pre_end=False).to(self.device)
self.quantize = VectorQuantizerTexture(
opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
1).to(self.device)
self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
opt["z_channels"],
1).to(self.device)
self.disc = Discriminator(
opt['n_channels'], opt['ndf'],
n_layers=opt['disc_layers']).to(self.device)
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
self.perceptual_weight = opt['perceptual_weight']
self.disc_start_step = opt['disc_start_step']
self.disc_weight_max = opt['disc_weight_max']
self.diff_aug = opt['diff_aug']
self.policy = "color,translation"
self.disc.train()
self.init_training_settings()
def feed_data(self, data):
x = data['image'].float().to(self.device)
mask = data['texture_mask'].float().to(self.device)
return x, mask
def training_step(self, data, step):
x, mask = self.feed_data(data)
xrec, codebook_loss = self.forward_step(x, mask)
# get recon/perceptual loss
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
nll_loss = recon_loss + self.perceptual_weight * p_loss
nll_loss = torch.mean(nll_loss)
# augment for input to discriminator
if self.diff_aug:
xrec = DiffAugment(xrec, policy=self.policy)
# update generator
logits_fake = self.disc(xrec)
g_loss = -torch.mean(logits_fake)
last_layer = self.decoder.conv_out.weight
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
self.disc_weight_max)
d_weight *= adopt_weight(1, step, self.disc_start_step)
loss = nll_loss + d_weight * g_loss + codebook_loss
self.log_dict["loss"] = loss
self.log_dict["l1"] = recon_loss.mean().item()
self.log_dict["perceptual"] = p_loss.mean().item()
self.log_dict["nll_loss"] = nll_loss.item()
self.log_dict["g_loss"] = g_loss.item()
self.log_dict["d_weight"] = d_weight
self.log_dict["codebook_loss"] = codebook_loss.item()
if step > self.disc_start_step:
if self.diff_aug:
logits_real = self.disc(
DiffAugment(x.contiguous().detach(), policy=self.policy))
else:
logits_real = self.disc(x.contiguous().detach())
logits_fake = self.disc(xrec.contiguous().detach(
)) # detach so that generator isn"t also updated
d_loss = hinge_d_loss(logits_real, logits_fake)
self.log_dict["d_loss"] = d_loss
else:
d_loss = None
return loss, d_loss
@torch.no_grad()
def inference(self, data_loader, save_dir):
self.encoder.eval()
self.decoder.eval()
self.quantize.eval()
self.quant_conv.eval()
self.post_quant_conv.eval()
loss_total = 0
num = 0
for _, data in enumerate(data_loader):
img_name = data['img_name'][0]
x, mask = self.feed_data(data)
xrec, _ = self.forward_step(x, mask)
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
nll_loss = recon_loss + self.perceptual_weight * p_loss
nll_loss = torch.mean(nll_loss)
loss_total += nll_loss
num += x.size(0)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
# convert logits to indices
xrec = torch.argmax(xrec, dim=1, keepdim=True)
xrec = F.one_hot(xrec, num_classes=x.shape[1])
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
img_cat = torch.cat([x, xrec], dim=3).detach()
img_cat = ((img_cat + 1) / 2)
img_cat = img_cat.clamp_(0, 1)
save_image(
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
return (loss_total / num).item()
def encode(self, x, mask):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h, mask)
return quant, emb_loss, info
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward_step(self, input, mask):
quant, diff, _ = self.encode(input, mask)
dec = self.decode(quant)
return dec, diff