import logging
import math
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.fcn_arch import MultiHeadFCNHead
from models.archs.unet_arch import UNet
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)
from models.losses.cross_entropy_loss import CrossEntropyLoss

logger = logging.getLogger('base')


class VQGANTextureAwareSpatialHierarchyInferenceModel():
    """Infer the bottom-level (fine) VQGAN token indices from the top-level
    (coarse) features.

    The two-level texture-aware VQGAN is loaded from pretrained checkpoints
    and kept frozen; only the guidance encoder and the multi-head index
    decoder are trained.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        # Top (coarse) branch of the pretrained two-level VQGAN.
        self.top_encoder = Encoder(
            ch=opt['top_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            double_z=opt['top_double_z'],
            dropout=opt['top_dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_quant_conv = torch.nn.Conv2d(opt['top_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['top_z_channels'],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

        # Bottom (fine) branch of the pretrained two-level VQGAN.
        self.bot_encoder = Encoder(
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            attn_resolutions=opt['bot_attn_resolutions'],
            ch_mult=opt['bot_ch_mult'],
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            double_z=opt['bot_double_z'],
            dropout=opt['bot_dropout']).to(self.device)
        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['codebook_spatial_size']).to(self.device)
        self.bot_quant_conv = torch.nn.Conv2d(opt['bot_z_channels'],
                                              opt['embed_dim'],
                                              1).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['bot_z_channels'],
                                                   1).to(self.device)

        self.load_bot_pretrain_network()

        # Trainable networks: a UNet that transforms the top-level features
        # and an FCN head with one output branch per texture class (18).
        self.guidance_encoder = UNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        self.index_decoder = MultiHeadFCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
            num_head=18).to(self.device)

        self.init_training_settings()

    def init_training_settings(self):
        # Optimize only the trainable networks; the VQGAN branches stay
        # frozen.
        optim_params = []
        for v in self.guidance_encoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.index_decoder.parameters():
            if v.requires_grad:
                optim_params.append(v)

        if self.opt['optimizer'] == 'Adam':
            self.optimizer = torch.optim.Adam(
                optim_params,
                self.opt['lr'],
                weight_decay=self.opt['weight_decay'])
        elif self.opt['optimizer'] == 'SGD':
            self.optimizer = torch.optim.SGD(
                optim_params,
                self.opt['lr'],
                momentum=self.opt['momentum'],
                weight_decay=self.opt['weight_decay'])
        else:
            raise ValueError(
                'Unknown optimizer: {}'.format(self.opt['optimizer']))
        self.log_dict = OrderedDict()
        if self.opt['loss_function'] == 'cross_entropy':
            self.loss_func = CrossEntropyLoss().to(self.device)
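
    # A minimal sketch of the training-related options read above; the
    # values here are illustrative assumptions, not the project's defaults:
    #
    #     opt = {
    #         'optimizer': 'Adam',
    #         'lr': 1e-4,
    #         'weight_decay': 0,
    #         'loss_function': 'cross_entropy',
    #         ...
    #     }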

    def load_top_pretrain_models(self):
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
        self.top_encoder.load_state_dict(
            top_vae_checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_quant_conv.load_state_dict(
            top_vae_checkpoint['quant_conv'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)
        self.top_encoder.eval()
        self.top_quantize.eval()
        self.top_quant_conv.eval()
        self.top_post_quant_conv.eval()

    def load_bot_pretrain_network(self):
        checkpoint = torch.load(self.opt['bot_vae_path'])
        self.bot_encoder.load_state_dict(
            checkpoint['bot_encoder'], strict=True)
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_quant_conv.load_state_dict(
            checkpoint['bot_quant_conv'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)

        self.bot_encoder.eval()
        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_quant_conv.eval()
        self.bot_post_quant_conv.eval()

    def top_encode(self, x, mask):
        h = self.top_encoder(x)
        h = self.top_quant_conv(h)
        quant, _, _ = self.top_quantize(h, mask)
        quant = self.top_post_quant_conv(quant)

        # The quantized map serves both as the decoder input (quant_t) and
        # as the feature fed to the guidance encoder (feature_t).
        return quant, quant

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.texture_mask = data['texture_mask'].float().to(self.device)
        self.get_gt_indices()

        # Downsample the texture mask to the bottom token grid (32x16) and
        # flatten it to one texture label per token.
        self.texture_tokens = F.interpolate(
            self.texture_mask, size=(32, 16),
            mode='nearest').view(self.image.size(0), -1).long()
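
    # Expected batch layout for `feed_data` (shapes are illustrative
    # assumptions; only the keys are fixed by this module):
    #
    #     data = {
    #         'image': float tensor of shape (B, C, H, W), scaled to [-1, 1],
    #         'texture_mask': per-pixel texture-class map of the same
    #             spatial size,
    #         'img_name': list of B file names (used by `inference`),
    #     }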

    def bot_encode(self, x, mask):
        h = self.bot_encoder(x)
        h = self.bot_quant_conv(h)
        # The bottom quantizer returns one ground-truth index map per
        # texture codebook.
        _, _, (_, _, indices_list) = self.bot_quantize(h, mask)

        return indices_list

    def get_gt_indices(self):
        self.quant_t, self.feature_t = self.top_encode(self.image,
                                                       self.texture_mask)
        self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)

    def index_to_image(self, index_bottom_list, texture_mask):
        quant_b = self.bot_quantize.get_codebook_entry(
            index_bottom_list, texture_mask,
            (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
             index_bottom_list[0].size(2), self.opt['bot_z_channels']))
        quant_b = self.bot_post_quant_conv(quant_b)
        bot_dec_res = self.bot_decoder_res(quant_b)

        # Decode from the top-level features while injecting the
        # bottom-level residual features.
        dec = self.decoder(self.quant_t, bot_h=bot_dec_res)

        return dec

    def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
        rec_img = self.index_to_image(rec_img_index, texture_mask)
        pred_img = self.index_to_image(pred_img_index, texture_mask)

        # Columns: input image, reconstruction from ground-truth indices,
        # top-level-only decode, reconstruction from predicted indices.
        base_img = self.decoder(self.quant_t)
        img_cat = torch.cat([
            self.image,
            rec_img,
            base_img,
            pred_img,
        ], dim=3).detach()
        # Map images from [-1, 1] to [0, 1] before saving.
        img_cat = (img_cat + 1) / 2
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def optimize_parameters(self):
        self.guidance_encoder.train()
        self.index_decoder.train()

        self.feature_enc = self.guidance_encoder(self.feature_t)
        self.memory_logits_list = self.index_decoder(self.feature_enc)

        # Sum the cross-entropy over the 18 texture heads; positions
        # labeled -1 are ignored.
        loss = 0
        for i in range(18):
            loss += self.loss_func(
                self.memory_logits_list[i],
                self.gt_indices_list[i],
                ignore_index=-1)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss_total'] = loss.item()

    def inference(self, data_loader, save_dir):
        self.guidance_encoder.eval()
        self.index_decoder.eval()

        acc = 0
        num = 0

        for data in data_loader:
            self.feed_data(data)
            img_name = data['img_name']

            num += self.image.size(0)

            texture_mask_flatten = self.texture_tokens.view(-1)
            # One flattened index map per texture codebook; tokens outside a
            # codebook's region keep the fill value -1.
            min_encodings_indices_list = [
                torch.full(
                    texture_mask_flatten.size(),
                    fill_value=-1,
                    dtype=torch.long,
                    device=texture_mask_flatten.device) for _ in range(18)
            ]
            with torch.no_grad():
                self.feature_enc = self.guidance_encoder(self.feature_t)
                memory_logits_list = self.index_decoder(self.feature_enc)

            batch_acc = 0
            for codebook_idx, memory_logits in enumerate(memory_logits_list):
                region_of_interest = texture_mask_flatten == codebook_idx
                if torch.sum(region_of_interest) > 0:
                    memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
                    batch_acc += torch.sum(
                        memory_indices_pred[region_of_interest] ==
                        self.gt_indices_list[codebook_idx].view(
                            -1)[region_of_interest])
                    min_encodings_indices_list[codebook_idx][
                        region_of_interest] = memory_indices_pred[
                            region_of_interest]
            min_encodings_indices_return_list = [
                min_encodings_indices.view(self.gt_indices_list[0].size())
                for min_encodings_indices in min_encodings_indices_list
            ]
            # The texture regions partition the token grid and all index
            # maps share one shape, so dividing the number of correct tokens
            # by the size of a single map gives the per-token accuracy.
            batch_acc = (batch_acc / self.gt_indices_list[0].numel() *
                         self.image.size(0))
            acc += batch_acc
            self.get_vis(min_encodings_indices_return_list,
                         self.gt_indices_list, self.texture_mask,
                         f'{save_dir}/{img_name[0]}')

        self.guidance_encoder.train()
        self.index_decoder.train()
        return (acc / num).item()
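
    # Typical validation use (a sketch; `val_loader` is a hypothetical
    # loader yielding the batch layout described above, and `save_dir` must
    # already exist):
    #
    #     val_acc = model.inference(val_loader, 'results/val_vis')
    #     logger.info(f'bottom-index accuracy: {val_acc:.4f}')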

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.guidance_encoder.eval()

        self.index_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)
        self.index_decoder.eval()

    def save_network(self, save_path):
        """Save the trainable networks.

        The resulting checkpoint holds the 'guidance_encoder' and
        'index_decoder' state dicts that `load_network` expects.

        Args:
            save_path (str): Path to save the checkpoint to.
        """
        save_dict = {}
        save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
        save_dict['index_decoder'] = self.index_decoder.state_dict()

        torch.save(save_dict, save_path)

    def update_learning_rate(self, epoch):
        """Update the learning rate according to the configured decay mode.

        Args:
            epoch (int): Current epoch.

        Returns:
            float: The updated learning rate.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # Decay linearly down to 5% of the initial lr at the turning
                # point (1 / 0.95 = 1.0526), then decay exponentially.
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def get_current_log(self):
        return self.log_dict
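

if __name__ == '__main__':
    # Minimal training sketch. Everything below is illustrative: the option
    # values, checkpoint paths, and `build_dataloader` are assumptions, not
    # part of this module; take the real values from the project's config.
    opt = {
        'is_train': True,
        'optimizer': 'Adam',
        'lr': 1e-4,  # assumed
        'weight_decay': 0,  # assumed
        'loss_function': 'cross_entropy',
        'lr_decay': 'step',
        'gamma': 0.1,  # assumed
        'step': 50,  # assumed
        'top_vae_path': 'pretrained/top_vae.pth',  # placeholder
        'bot_vae_path': 'pretrained/bot_vae.pth',  # placeholder
        # ... plus the 'top_*', 'bot_*', 'fc_*', 'embed_dim',
        # 'encoder_in_channels' and 'codebook_spatial_size' keys read in
        # __init__.
    }
    model = VQGANTextureAwareSpatialHierarchyInferenceModel(opt)

    train_loader = build_dataloader(opt)  # hypothetical helper
    for epoch in range(100):
        model.update_learning_rate(epoch)
        for batch in train_loader:
            # Each batch provides 'image', 'texture_mask' and 'img_name'.
            model.feed_data(batch)
            model.optimize_parameters()
        logger.info('epoch %d: %s', epoch, dict(model.get_current_log()))
    model.save_network('experiments/hierarchy_inference.pth')  # placeholder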