import scipy.io
import scipy.misc
from glob import glob
import os
import numpy as np
from thirdparty.face_of_art.ops import *
import tensorflow as tf
from tensorflow import contrib
from thirdparty.face_of_art.menpo_functions import *
from thirdparty.face_of_art.logging_functions import *
from thirdparty.face_of_art.data_loading_functions import *
class DeepHeatmapsModel(object):
"""facial landmark localization Network"""
def __init__(self, mode='TRAIN', train_iter=100000, batch_size=10, learning_rate=1e-3, l_weight_primary=1.,
l_weight_fusion=1.,l_weight_upsample=3.,adam_optimizer=True,momentum=0.95,step=100000, gamma=0.1,reg=0,
weight_initializer='xavier', weight_initializer_std=0.01, bias_initializer=0.0, image_size=256,c_dim=3,
num_landmarks=68, sigma=1.5, scale=1, margin=0.25, bb_type='gt', win_mult=3.33335,
augment_basic=True,augment_texture=False, p_texture=0., augment_geom=False, p_geom=0.,
output_dir='output', save_model_path='model',
save_sample_path='sample', save_log_path='logs', test_model_path='model/deep_heatmaps-50000',
pre_train_path='model/deep_heatmaps-50000', load_pretrain=False, load_primary_only=False,
img_path='data', test_data='full', valid_data='full', valid_size=0, log_valid_every=5,
train_crop_dir='crop_gt_margin_0.25', img_dir_ns='crop_gt_margin_0.25_ns',
print_every=100, save_every=5000, sample_every=5000, sample_grid=9, sample_to_log=True,
debug_data_size=20, debug=False, epoch_data_dir='epoch_data', use_epoch_data=False, menpo_verbose=True):
# define some extra parameters
self.log_histograms = False # save weight + gradient histogram to log
self.save_valid_images = True # sample heat maps of validation images
self.sample_per_channel = False # sample heatmaps separately for each landmark
        # for fine-tuning, set reset_training_op=True; when resuming training, set reset_training_op=False
self.reset_training_op = False
self.fast_img_gen = True
self.compute_nme = True # compute normalized mean error
self.config = tf.ConfigProto()
self.config.gpu_options.allow_growth = True
# sampling and logging parameters
self.print_every = print_every # print losses to screen + log
self.save_every = save_every # save model
self.sample_every = sample_every # save images of gen heat maps compared to GT
self.sample_grid = sample_grid # number of training images in sample
self.sample_to_log = sample_to_log # sample images to log instead of disk
self.log_valid_every = log_valid_every # log validation loss (in epochs)
self.debug = debug
self.debug_data_size = debug_data_size
self.use_epoch_data = use_epoch_data
self.epoch_data_dir = epoch_data_dir
self.load_pretrain = load_pretrain
self.load_primary_only = load_primary_only
self.pre_train_path = pre_train_path
self.mode = mode
self.train_iter = train_iter
self.learning_rate = learning_rate
self.image_size = image_size
self.c_dim = c_dim
self.batch_size = batch_size
self.num_landmarks = num_landmarks
self.save_log_path = save_log_path
self.save_sample_path = save_sample_path
self.save_model_path = save_model_path
self.test_model_path = test_model_path
self.img_path=img_path
self.momentum = momentum
self.step = step # for lr decay
self.gamma = gamma # for lr decay
self.reg = reg # weight decay scale
self.l_weight_primary = l_weight_primary # primary loss weight
self.l_weight_fusion = l_weight_fusion # fusion loss weight
self.l_weight_upsample = l_weight_upsample # upsample loss weight
self.weight_initializer = weight_initializer # random_normal or xavier
self.weight_initializer_std = weight_initializer_std
self.bias_initializer = bias_initializer
self.adam_optimizer = adam_optimizer
self.sigma = sigma # sigma for heatmap generation
self.scale = scale # scale for image normalization 255 / 1 / 0
self.win_mult = win_mult # gaussian filter size for cpu/gpu approximation: 2 * sigma * win_mult + 1
        self.test_data = test_data  # if mode is TEST, this chooses the set to use: full/common/challenging/test/art
self.train_crop_dir = train_crop_dir
self.img_dir_ns = os.path.join(img_path,img_dir_ns)
self.augment_basic = augment_basic # perform basic augmentation (rotation,flip,crop)
self.augment_texture = augment_texture # perform artistic texture augmentation (NS)
self.p_texture = p_texture # initial probability of artistic texture augmentation
self.augment_geom = augment_geom # perform artistic geometric augmentation
self.p_geom = p_geom # initial probability of artistic geometric augmentation
self.valid_size = valid_size
self.valid_data = valid_data
# load image, bb and landmark data using menpo
self.bb_dir = os.path.join(img_path, 'Bounding_Boxes')
self.bb_dictionary = load_bb_dictionary(self.bb_dir, mode, test_data=self.test_data)
# use pre-augmented data, to save time during training
if self.use_epoch_data:
epoch_0 = os.path.join(self.epoch_data_dir, '0')
self.img_menpo_list = load_menpo_image_list(
img_path, train_crop_dir=epoch_0, img_dir_ns=None, mode=mode, bb_dictionary=self.bb_dictionary,
image_size=self.image_size, test_data=self.test_data, augment_basic=False, augment_texture=False,
augment_geom=False, verbose=menpo_verbose)
else:
self.img_menpo_list = load_menpo_image_list(
img_path, train_crop_dir, self.img_dir_ns, mode, bb_dictionary=self.bb_dictionary,
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.test_data,
augment_basic=augment_basic, augment_texture=augment_texture, p_texture=p_texture,
augment_geom=augment_geom, p_geom=p_geom, verbose=menpo_verbose)
if mode == 'TRAIN':
train_params = locals()
print_training_params_to_file(train_params) # save init parameters
self.train_inds = np.arange(len(self.img_menpo_list))
if self.debug:
self.train_inds = self.train_inds[:self.debug_data_size]
self.img_menpo_list = self.img_menpo_list[self.train_inds]
if valid_size > 0:
self.valid_bb_dictionary = load_bb_dictionary(self.bb_dir, 'TEST', test_data=self.valid_data)
self.valid_img_menpo_list = load_menpo_image_list(
img_path, train_crop_dir, self.img_dir_ns, 'TEST', bb_dictionary=self.valid_bb_dictionary,
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.valid_data,
verbose=menpo_verbose)
np.random.seed(0)
self.val_inds = np.arange(len(self.valid_img_menpo_list))
np.random.shuffle(self.val_inds)
self.val_inds = self.val_inds[:self.valid_size]
self.valid_img_menpo_list = self.valid_img_menpo_list[self.val_inds]
self.valid_images_loaded =\
np.zeros([self.valid_size, self.image_size, self.image_size, self.c_dim]).astype('float32')
                self.valid_gt_maps_small_loaded =\
                    np.zeros([self.valid_size, int(self.image_size / 4), int(self.image_size / 4),
                              self.num_landmarks]).astype('float32')
self.valid_gt_maps_loaded =\
np.zeros([self.valid_size, self.image_size, self.image_size, self.num_landmarks]
).astype('float32')
self.valid_landmarks_loaded = np.zeros([self.valid_size, num_landmarks, 2]).astype('float32')
self.valid_landmarks_pred = np.zeros([self.valid_size, self.num_landmarks, 2]).astype('float32')
load_images_landmarks_approx_maps_alloc_once(
self.valid_img_menpo_list, np.arange(self.valid_size), images=self.valid_images_loaded,
maps_small=self.valid_gt_maps_small_loaded, maps=self.valid_gt_maps_loaded,
landmarks=self.valid_landmarks_loaded, image_size=self.image_size,
num_landmarks=self.num_landmarks, scale=self.scale, win_mult=self.win_mult, sigma=self.sigma,
save_landmarks=self.compute_nme)
if self.valid_size > self.sample_grid:
self.valid_gt_maps_loaded = self.valid_gt_maps_loaded[:self.sample_grid]
self.valid_gt_maps_small_loaded = self.valid_gt_maps_small_loaded[:self.sample_grid]
else:
self.val_inds = None
self.epoch_inds_shuffle = train_val_shuffle_inds_per_epoch(
self.val_inds, self.train_inds, train_iter, batch_size, save_log_path)
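    # The placeholders below cover both training and testing: images enter at full
    # resolution (image_size x image_size x c_dim), ground-truth heatmaps are fed at
    # full resolution for the upsampling head and at 1/4 resolution for the primary
    # and fusion heads, and the landmark tensors are used only for NME logging.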
def add_placeholders(self):
if self.mode == 'TEST':
self.images = tf.placeholder(
tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'images')
self.heatmaps = tf.placeholder(
tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'heatmaps')
self.heatmaps_small = tf.placeholder(
tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'heatmaps_small')
self.lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'lms')
self.pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'pred_lms')
elif self.mode == 'TRAIN':
self.images = tf.placeholder(
tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'train_images')
self.heatmaps = tf.placeholder(
tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'train_heatmaps')
self.heatmaps_small = tf.placeholder(
tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'train_heatmaps_small')
self.train_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_lms')
self.train_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_pred_lms')
self.valid_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_lms')
self.valid_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_pred_lms')
# self.p_texture_log = tf.placeholder(tf.float32, [])
# self.p_geom_log = tf.placeholder(tf.float32, [])
# self.sparse_hm_small = tf.placeholder(tf.float32, [None, int(self.image_size/4), int(self.image_size/4), 1])
# self.sparse_hm = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1])
if self.sample_to_log:
row = int(np.sqrt(self.sample_grid))
self.log_image_map_small = tf.placeholder(
tf.uint8, [None, row * int(self.image_size/4), 3 * row * int(self.image_size/4), self.c_dim],
'sample_img_map_small')
self.log_image_map = tf.placeholder(
tf.uint8, [None, row * self.image_size, 3 * row * self.image_size, self.c_dim],
'sample_img_map')
if self.sample_per_channel:
row = np.ceil(np.sqrt(self.num_landmarks)).astype(np.int64)
self.log_map_channels_small = tf.placeholder(
tf.uint8, [None, row * int(self.image_size/4), 2 * row * int(self.image_size/4), self.c_dim],
'sample_map_channels_small')
self.log_map_channels = tf.placeholder(
tf.uint8, [None, row * self.image_size, 2 * row * self.image_size, self.c_dim],
'sample_map_channels')
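    # Network definition: a primary sub-network produces coarse heatmaps at 1/4
    # resolution (two pooled conv blocks followed by multi-dilation convolutions),
    # a fusion sub-network refines them from concatenated intermediate features
    # (conv_3 and conv_7), and a final deconvolution initialized for bilinear
    # upsampling brings the fused maps back to the input resolution.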
def heatmaps_network(self, input_images, reuse=None, name='pred_heatmaps'):
with tf.name_scope(name):
if self.weight_initializer == 'xavier':
weight_initializer = contrib.layers.xavier_initializer()
else:
weight_initializer = tf.random_normal_initializer(stddev=self.weight_initializer_std)
bias_init = tf.constant_initializer(self.bias_initializer)
with tf.variable_scope('heatmaps_network'):
with tf.name_scope('primary_net'):
l1 = conv_relu_pool(input_images, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
reuse=reuse, var_scope='conv_1')
l2 = conv_relu_pool(l1, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
reuse=reuse, var_scope='conv_2')
l3 = conv_relu(l2, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
reuse=reuse, var_scope='conv_3')
l4_1 = conv_relu(l3, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_1')
l4_2 = conv_relu(l3, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_2')
l4_3 = conv_relu(l3, 3, 128, conv_dilation=3, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_3')
l4_4 = conv_relu(l3, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_4')
l4 = tf.concat([l4_1, l4_2, l4_3, l4_4], 3, name='conv_4')
l5_1 = conv_relu(l4, 3, 256, conv_dilation=1, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_1')
l5_2 = conv_relu(l4, 3, 256, conv_dilation=2, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_2')
l5_3 = conv_relu(l4, 3, 256, conv_dilation=3, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_3')
l5_4 = conv_relu(l4, 3, 256, conv_dilation=4, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_4')
l5 = tf.concat([l5_1, l5_2, l5_3, l5_4], 3, name='conv_5')
l6 = conv_relu(l5, 1, 512, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_6')
l7 = conv_relu(l6, 1, 256, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_7')
primary_out = conv(l7, 1, self.num_landmarks, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_8')
with tf.name_scope('fusion_net'):
l_fsn_0 = tf.concat([l3, l7], 3, name='conv_3_7_fsn')
l_fsn_1_1 = conv_relu(l_fsn_0, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_1')
l_fsn_1_2 = conv_relu(l_fsn_0, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_2')
l_fsn_1_3 = conv_relu(l_fsn_0, 3, 64, conv_dilation=3, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_3')
l_fsn_1 = tf.concat([l_fsn_1_1, l_fsn_1_2, l_fsn_1_3], 3, name='conv_fsn_1')
l_fsn_2_1 = conv_relu(l_fsn_1, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_1')
l_fsn_2_2 = conv_relu(l_fsn_1, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_2')
l_fsn_2_3 = conv_relu(l_fsn_1, 3, 64, conv_dilation=4, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_3')
l_fsn_2_4 = conv_relu(l_fsn_1, 5, 64, conv_dilation=3, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_4')
l_fsn_2 = tf.concat([l_fsn_2_1, l_fsn_2_2, l_fsn_2_3, l_fsn_2_4], 3, name='conv_fsn_2')
l_fsn_3_1 = conv_relu(l_fsn_2, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_1')
l_fsn_3_2 = conv_relu(l_fsn_2, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_2')
l_fsn_3_3 = conv_relu(l_fsn_2, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_3')
l_fsn_3_4 = conv_relu(l_fsn_2, 5, 128, conv_dilation=3, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_4')
l_fsn_3 = tf.concat([l_fsn_3_1, l_fsn_3_2, l_fsn_3_3, l_fsn_3_4], 3, name='conv_fsn_3')
l_fsn_4 = conv_relu(l_fsn_3, 1, 256, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_4')
fusion_out = conv(l_fsn_4, 1, self.num_landmarks, conv_ker_init=weight_initializer,
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_5')
with tf.name_scope('upsample_net'):
out = deconv(fusion_out, 8, self.num_landmarks, conv_stride=4,
conv_ker_init=deconv2d_bilinear_upsampling_initializer(
[8, 8, self.num_landmarks, self.num_landmarks]), conv_bias_init=bias_init,
reuse=reuse, var_scope='deconv_1')
self.all_layers = [l1, l2, l3, l4, l5, l6, l7, primary_out, l_fsn_1, l_fsn_2, l_fsn_3, l_fsn_4,
fusion_out, out]
return primary_out, fusion_out, out
def build_model(self):
self.pred_hm_p, self.pred_hm_f, self.pred_hm_u = self.heatmaps_network(self.images,name='heatmaps_prediction')
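    # Loss definition: in TRAIN mode the total loss is a weighted sum of L2 distances
    # between predicted and ground-truth heatmaps of the primary, fusion and upsample
    # outputs (scaled by 1000), plus optional L2 weight decay on non-bias variables.
    # NME (RMS landmark error normalized by the inter-pupil distance) is computed for
    # train/validation monitoring and serves as the TEST metric.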
def create_loss_ops(self):
def nme_norm_eyes(pred_landmarks, real_landmarks, normalize=True, name='NME'):
"""calculate normalized mean error on landmarks - normalize with inter pupil distance"""
with tf.name_scope(name):
with tf.name_scope('real_pred_landmarks_rmse'):
# calculate RMS ERROR between GT and predicted lms
landmarks_rms_err = tf.reduce_mean(
tf.sqrt(tf.reduce_sum(tf.square(pred_landmarks - real_landmarks), axis=2)), axis=1)
if normalize:
# normalize RMS ERROR with inter-pupil distance of GT lms
with tf.name_scope('inter_pupil_dist'):
with tf.name_scope('left_eye_center'):
p1 = tf.reduce_mean(tf.slice(real_landmarks, [0, 42, 0], [-1, 6, 2]), axis=1)
with tf.name_scope('right_eye_center'):
p2 = tf.reduce_mean(tf.slice(real_landmarks, [0, 36, 0], [-1, 6, 2]), axis=1)
eye_dist = tf.sqrt(tf.reduce_sum(tf.square(p1 - p2), axis=1))
return landmarks_rms_err / eye_dist
else:
return landmarks_rms_err
        if self.mode == 'TRAIN':
# calculate L2 loss between ideal and predicted heatmaps
primary_maps_diff = self.pred_hm_p - self.heatmaps_small
fusion_maps_diff = self.pred_hm_f - self.heatmaps_small
upsample_maps_diff = self.pred_hm_u - self.heatmaps
self.l2_primary = tf.reduce_mean(tf.square(primary_maps_diff))
self.l2_fusion = tf.reduce_mean(tf.square(fusion_maps_diff))
self.l2_upsample = tf.reduce_mean(tf.square(upsample_maps_diff))
self.total_loss = 1000.*(self.l_weight_primary * self.l2_primary + self.l_weight_fusion * self.l2_fusion +
self.l_weight_upsample * self.l2_upsample)
# add weight decay
self.total_loss += self.reg * tf.add_n(
[tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name])
# compute normalized mean error on gt vs. predicted landmarks (for validation)
if self.compute_nme:
self.nme_loss = tf.reduce_mean(nme_norm_eyes(self.train_pred_lms, self.train_lms))
if self.valid_size > 0 and self.compute_nme:
self.valid_nme_loss = tf.reduce_mean(nme_norm_eyes(self.valid_pred_lms, self.valid_lms))
elif self.mode == 'TEST' and self.compute_nme:
self.nme_per_image = nme_norm_eyes(self.pred_lms, self.lms)
self.nme_loss = tf.reduce_mean(self.nme_per_image)
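    # Validation prediction: runs the upsampled heatmap output over the validation set
    # in batches of self.batch_size (plus one smaller batch for any remainder) and
    # writes the decoded landmarks into self.valid_landmarks_pred in place.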
def predict_valid_landmarks_in_batches(self, images, session):
num_images=int(images.shape[0])
num_batches = int(1.*num_images/self.batch_size)
if num_batches == 0:
batch_size = num_images
num_batches = 1
else:
batch_size = self.batch_size
for j in range(num_batches):
batch_images = images[j * batch_size:(j + 1) * batch_size,:,:,:]
batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
batch_heat_maps_to_landmarks_alloc_once(
batch_maps=batch_maps_pred, batch_landmarks=self.valid_landmarks_pred[j * batch_size:(j + 1) * batch_size, :, :],
batch_size=batch_size,image_size=self.image_size,num_landmarks=self.num_landmarks)
        remainder = num_images - num_batches * batch_size
        if remainder > 0:
            batch_images = images[-remainder:, :, :, :]
            batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
            batch_heat_maps_to_landmarks_alloc_once(
                batch_maps=batch_maps_pred,
                batch_landmarks=self.valid_landmarks_pred[-remainder:, :, :],
                batch_size=remainder, image_size=self.image_size, num_landmarks=self.num_landmarks)
def create_summary_ops(self):
"""create summary ops for logging"""
# loss summary
l2_primary = tf.summary.scalar('l2_primary', self.l2_primary)
l2_fusion = tf.summary.scalar('l2_fusion', self.l2_fusion)
l2_upsample = tf.summary.scalar('l2_upsample', self.l2_upsample)
l_total = tf.summary.scalar('l_total', self.total_loss)
self.batch_summary_op = tf.summary.merge([l2_primary,l2_fusion,l2_upsample,l_total])
if self.compute_nme:
nme = tf.summary.scalar('nme', self.nme_loss)
self.batch_summary_op = tf.summary.merge([self.batch_summary_op, nme])
if self.log_histograms:
var_summary = [tf.summary.histogram(var.name,var) for var in tf.trainable_variables()]
grads = tf.gradients(self.total_loss, tf.trainable_variables())
grads = list(zip(grads, tf.trainable_variables()))
grad_summary = [tf.summary.histogram(var.name+'/grads',grad) for grad,var in grads]
activ_summary = [tf.summary.histogram(layer.name, layer) for layer in self.all_layers]
self.batch_summary_op = tf.summary.merge([self.batch_summary_op, var_summary, grad_summary, activ_summary])
if self.valid_size > 0 and self.compute_nme:
self.valid_summary = tf.summary.scalar('valid_nme', self.valid_nme_loss)
if self.sample_to_log:
img_map_summary_small = tf.summary.image('compare_map_to_gt_small', self.log_image_map_small)
img_map_summary = tf.summary.image('compare_map_to_gt', self.log_image_map)
if self.sample_per_channel:
map_channels_summary = tf.summary.image('compare_map_channels_to_gt', self.log_map_channels)
map_channels_summary_small = tf.summary.image('compare_map_channels_to_gt_small',
self.log_map_channels_small)
self.img_summary = tf.summary.merge(
[img_map_summary, img_map_summary_small,map_channels_summary,map_channels_summary_small])
else:
self.img_summary = tf.summary.merge([img_map_summary, img_map_summary_small])
if self.valid_size >= self.sample_grid:
img_map_summary_valid_small = tf.summary.image('compare_map_to_gt_small_valid', self.log_image_map_small)
img_map_summary_valid = tf.summary.image('compare_map_to_gt_valid', self.log_image_map)
if self.sample_per_channel:
map_channels_summary_valid_small = tf.summary.image('compare_map_channels_to_gt_small_valid',
self.log_map_channels_small)
map_channels_summary_valid = tf.summary.image('compare_map_channels_to_gt_valid',
self.log_map_channels)
self.img_summary_valid = tf.summary.merge(
[img_map_summary_valid,img_map_summary_valid_small,map_channels_summary_valid,
map_channels_summary_valid_small])
else:
self.img_summary_valid = tf.summary.merge([img_map_summary_valid, img_map_summary_valid_small])
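    # Training loop: iterates over pre-shuffled per-epoch index arrays, generates
    # ground-truth heatmaps on the fly with the Gaussian filters created below, and
    # periodically logs losses/NME, saves checkpoints, and samples predicted-vs-GT
    # heatmap images (to TensorBoard or to disk, depending on sample_to_log).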
def train(self):
# set random seed
tf.set_random_seed(1234)
np.random.seed(1234)
# build a graph
# add placeholders
self.add_placeholders()
# build model
self.build_model()
# create loss ops
self.create_loss_ops()
# create summary ops
self.create_summary_ops()
# create optimizer and training op
global_step = tf.Variable(0, trainable=False)
lr = tf.train.exponential_decay(self.learning_rate,global_step, self.step, self.gamma, staircase=True)
if self.adam_optimizer:
optimizer = tf.train.AdamOptimizer(lr)
else:
optimizer = tf.train.MomentumOptimizer(lr, self.momentum)
train_op = optimizer.minimize(self.total_loss,global_step=global_step)
with tf.Session(config=self.config) as sess:
tf.global_variables_initializer().run()
            # load pre-trained weights if load_pretrain is True
if self.load_pretrain:
                print()
print('*** loading pre-trained weights from: '+self.pre_train_path+' ***')
if self.load_primary_only:
print('*** loading primary-net only ***')
primary_var = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if
('deconv_' not in v.name) and ('_fsn_' not in v.name)]
loader = tf.train.Saver(var_list=primary_var)
else:
loader = tf.train.Saver()
loader.restore(sess, self.pre_train_path)
print("*** Model restore finished, current global step: %d" % global_step.eval())
            # for fine-tuning, set reset_training_op=True; when resuming training, set reset_training_op=False
if self.reset_training_op:
print ("resetting optimizer and global step")
opt_var_list = [optimizer.get_slot(var, name) for name in optimizer.get_slot_names()
for var in tf.global_variables() if optimizer.get_slot(var, name) is not None]
opt_var_list_init = tf.variables_initializer(opt_var_list)
opt_var_list_init.run()
sess.run(global_step.initializer)
# create model saver and file writer
summary_writer = tf.summary.FileWriter(logdir=self.save_log_path, graph=tf.get_default_graph())
saver = tf.train.Saver()
print('\n*** Start Training ***')
# initialize some variables before training loop
resume_step = global_step.eval()
num_train_images = len(self.img_menpo_list)
batches_in_epoch = int(float(num_train_images) / float(self.batch_size))
epoch = int(resume_step / batches_in_epoch)
img_inds = self.epoch_inds_shuffle[epoch, :]
log_valid = True
log_valid_images = True
# allocate space for batch images, maps and landmarks
batch_images = np.zeros([self.batch_size, self.image_size, self.image_size, self.c_dim]).astype(
'float32')
batch_lms = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
batch_lms_pred = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
batch_maps_small = np.zeros((self.batch_size, int(self.image_size/4),
int(self.image_size/4), self.num_landmarks)).astype('float32')
batch_maps = np.zeros((self.batch_size, self.image_size, self.image_size,
self.num_landmarks)).astype('float32')
# create gaussians for heatmap generation
gaussian_filt_large = create_gaussian_filter(sigma=self.sigma, win_mult=self.win_mult)
gaussian_filt_small = create_gaussian_filter(sigma=1.*self.sigma/4, win_mult=self.win_mult)
# training loop
for step in range(resume_step, self.train_iter):
j = step % batches_in_epoch # j==0 if we finished an epoch
# if we finished an epoch and this isn't the first step
if step > resume_step and j == 0:
epoch += 1
img_inds = self.epoch_inds_shuffle[epoch, :] # get next shuffled image inds
log_valid = True
log_valid_images = True
if self.use_epoch_data: # if using pre-augmented data, load epoch directory
epoch_dir = os.path.join(self.epoch_data_dir, str(epoch))
self.img_menpo_list = load_menpo_image_list(
self.img_path, train_crop_dir=epoch_dir, img_dir_ns=None, mode=self.mode,
bb_dictionary=self.bb_dictionary, image_size=self.image_size, test_data=self.test_data,
augment_basic=False, augment_texture=False, augment_geom=False)
# get batch indices
batch_inds = img_inds[j * self.batch_size:(j + 1) * self.batch_size]
# load batch images, gt maps and landmarks
load_images_landmarks_approx_maps_alloc_once(
self.img_menpo_list, batch_inds, images=batch_images, maps_small=batch_maps_small,
maps=batch_maps, landmarks=batch_lms, image_size=self.image_size,
num_landmarks=self.num_landmarks, scale=self.scale, gauss_filt_large=gaussian_filt_large,
gauss_filt_small=gaussian_filt_small, win_mult=self.win_mult, sigma=self.sigma,
save_landmarks=self.compute_nme)
feed_dict_train = {self.images: batch_images, self.heatmaps: batch_maps,
self.heatmaps_small: batch_maps_small}
# train on batch
sess.run(train_op, feed_dict_train)
# save to log and print status
if step == resume_step or (step + 1) % self.print_every == 0:
# train data log
if self.compute_nme:
batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
batch_heat_maps_to_landmarks_alloc_once(
batch_maps=batch_maps_pred,batch_landmarks=batch_lms_pred,
batch_size=self.batch_size, image_size=self.image_size,
num_landmarks=self.num_landmarks)
train_feed_dict_log = {
self.images: batch_images, self.heatmaps: batch_maps,
self.heatmaps_small: batch_maps_small, self.train_lms: batch_lms,
self.train_pred_lms: batch_lms_pred}
summary, l_p, l_f, l_t, nme = sess.run(
[self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss,
self.nme_loss],
train_feed_dict_log)
print (
'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f]'
' total loss: [%.6f] NME: [%.6f]' % (
epoch, step + 1, self.train_iter, l_p, l_f, l_t, nme))
else:
train_feed_dict_log = {self.images: batch_images, self.heatmaps: batch_maps,
self.heatmaps_small: batch_maps_small}
summary, l_p, l_f, l_t = sess.run(
[self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss],
train_feed_dict_log)
print (
'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f] total loss: [%.6f]'
% (epoch, step + 1, self.train_iter, l_p, l_f, l_t))
summary_writer.add_summary(summary, step)
# valid data log
if self.valid_size > 0 and (log_valid and epoch % self.log_valid_every == 0) \
and self.compute_nme:
log_valid = False
self.predict_valid_landmarks_in_batches(self.valid_images_loaded, sess)
valid_feed_dict_log = {
self.valid_lms: self.valid_landmarks_loaded,
self.valid_pred_lms: self.valid_landmarks_pred}
v_summary, v_nme = sess.run([self.valid_summary, self.valid_nme_loss],
valid_feed_dict_log)
summary_writer.add_summary(v_summary, step)
print (
'epoch: [%d] step: [%d/%d] valid NME: [%.6f]' % (
epoch, step + 1, self.train_iter, v_nme))
# save model
if (step + 1) % self.save_every == 0:
saver.save(sess, os.path.join(self.save_model_path, 'deep_heatmaps'), global_step=step + 1)
print ('model/deep-heatmaps-%d saved' % (step + 1))
# save images
if step == resume_step or (step + 1) % self.sample_every == 0:
batch_maps_small_pred = sess.run(self.pred_hm_p, {self.images: batch_images})
if not self.compute_nme:
batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
batch_lms_pred = None
merged_img = merge_images_landmarks_maps_gt(
batch_images.copy(), batch_maps_pred, batch_maps, landmarks=batch_lms_pred,
image_size=self.image_size, num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
scale=self.scale, circle_size=2, fast=self.fast_img_gen)
merged_img_small = merge_images_landmarks_maps_gt(
batch_images.copy(), batch_maps_small_pred, batch_maps_small,
image_size=self.image_size,
num_landmarks=self.num_landmarks, num_samples=self.sample_grid, scale=self.scale,
circle_size=0, fast=self.fast_img_gen)
if self.sample_per_channel:
map_per_channel = map_comapre_channels(
batch_images.copy(), batch_maps_pred, batch_maps, image_size=self.image_size,
num_landmarks=self.num_landmarks, scale=self.scale)
map_per_channel_small = map_comapre_channels(
batch_images.copy(), batch_maps_small_pred, batch_maps_small, image_size=int(self.image_size/4),
num_landmarks=self.num_landmarks, scale=self.scale)
if self.sample_to_log: # save heatmap images to log
if self.sample_per_channel:
summary_img = sess.run(
self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
self.log_map_channels: np.expand_dims(map_per_channel, 0),
self.log_image_map_small: np.expand_dims(merged_img_small, 0),
self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
else:
summary_img = sess.run(
self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
summary_writer.add_summary(summary_img, step)
if (self.valid_size >= self.sample_grid) and self.save_valid_images and\
(log_valid_images and epoch % self.log_valid_every == 0):
log_valid_images = False
batch_maps_small_pred_val,batch_maps_pred_val =\
sess.run([self.pred_hm_p,self.pred_hm_u],
{self.images: self.valid_images_loaded[:self.sample_grid]})
merged_img_small = merge_images_landmarks_maps_gt(
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
self.valid_gt_maps_small_loaded, image_size=self.image_size,
num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
scale=self.scale, circle_size=0, fast=self.fast_img_gen)
merged_img = merge_images_landmarks_maps_gt(
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
self.valid_gt_maps_loaded, image_size=self.image_size,
num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
scale=self.scale, circle_size=2, fast=self.fast_img_gen)
if self.sample_per_channel:
map_per_channel_small = map_comapre_channels(
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
self.valid_gt_maps_small_loaded, image_size=int(self.image_size / 4),
num_landmarks=self.num_landmarks, scale=self.scale)
map_per_channel = map_comapre_channels(
                                    self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
self.valid_gt_maps_loaded, image_size=self.image_size,
num_landmarks=self.num_landmarks, scale=self.scale)
summary_img = sess.run(
self.img_summary_valid,
{self.log_image_map: np.expand_dims(merged_img, 0),
self.log_map_channels: np.expand_dims(map_per_channel, 0),
self.log_image_map_small: np.expand_dims(merged_img_small, 0),
self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
else:
summary_img = sess.run(
self.img_summary_valid,
{self.log_image_map: np.expand_dims(merged_img, 0),
self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
summary_writer.add_summary(summary_img, step)
else: # save heatmap images to directory
sample_path_imgs = os.path.join(
self.save_sample_path, 'epoch-%d-train-iter-%d-1.png' % (epoch, step + 1))
sample_path_imgs_small = os.path.join(
self.save_sample_path, 'epoch-%d-train-iter-%d-1-s.png' % (epoch, step + 1))
scipy.misc.imsave(sample_path_imgs, merged_img)
scipy.misc.imsave(sample_path_imgs_small, merged_img_small)
if self.sample_per_channel:
sample_path_ch_maps = os.path.join(
self.save_sample_path, 'epoch-%d-train-iter-%d-3.png' % (epoch, step + 1))
sample_path_ch_maps_small = os.path.join(
self.save_sample_path, 'epoch-%d-train-iter-%d-3-s.png' % (epoch, step + 1))
scipy.misc.imsave(sample_path_ch_maps, map_per_channel)
scipy.misc.imsave(sample_path_ch_maps_small, map_per_channel_small)
print('*** Finished Training ***')
def get_image_maps(self, test_image, reuse=None, norm=False):
""" returns heatmaps of input image (menpo image object)"""
self.add_placeholders()
# build model
pred_hm_p, pred_hm_f, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
with tf.Session(config=self.config) as sess:
# load trained parameters
saver = tf.train.Saver()
saver.restore(sess, self.test_model_path)
_, model_name = os.path.split(self.test_model_path)
test_image = test_image.pixels_with_channels_at_back().astype('float32')
if norm:
                if self.scale == 255:
                    test_image *= 255
                elif self.scale == 0:
                    test_image = 2 * test_image - 1
map_primary, map_fusion, map_upsample = sess.run(
[pred_hm_p, pred_hm_f, pred_hm_u], {self.images: np.expand_dims(test_image, 0)})
return map_primary, map_fusion, map_upsample
def get_landmark_predictions(self, img_list, pdm_models_dir, clm_model_path, reuse=None, map_to_input_size=False):
"""returns dictionary with landmark predictions of each step of the ECpTp algorithm and ECT"""
from thirdparty.face_of_art.pdm_clm_functions import feature_based_pdm_corr, clm_correct
jaw_line_inds = np.arange(0, 17)
left_brow_inds = np.arange(17, 22)
right_brow_inds = np.arange(22, 27)
self.add_placeholders()
# build model
_, _, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
with tf.Session(config=self.config) as sess:
# load trained parameters
saver = tf.train.Saver()
saver.restore(sess, self.test_model_path)
_, model_name = os.path.split(self.test_model_path)
e_list = []
ect_list = []
ecp_list = []
ecpt_list = []
ecptp_jaw_list = []
ecptp_out_list = []
for test_image in img_list:
if map_to_input_size:
test_image_transform = test_image[1]
test_image=test_image[0]
# get landmarks for estimation stage
if test_image.n_channels < 3:
test_image_map = sess.run(
pred_hm_u, {self.images: np.expand_dims(
gray2rgb(test_image.pixels_with_channels_at_back()).astype('float32'), 0)})
else:
test_image_map = sess.run(
pred_hm_u, {self.images: np.expand_dims(
test_image.pixels_with_channels_at_back().astype('float32'), 0)})
init_lms = heat_maps_to_landmarks(np.squeeze(test_image_map))
# get landmarks for part-based correction stage
p_pdm_lms = feature_based_pdm_corr(lms_init=init_lms, models_dir=pdm_models_dir, train_type='basic')
# get landmarks for part-based tuning stage
try: # clm may not converge
pdm_clm_lms = clm_correct(
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=p_pdm_lms)
                except Exception:
pdm_clm_lms = p_pdm_lms.copy()
# get landmarks ECT
try: # clm may not converge
ect_lms = clm_correct(
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=init_lms)
                except Exception:
ect_lms = p_pdm_lms.copy()
# get landmarks for ECpTp_out (tune jaw and eyebrows)
ecptp_out = p_pdm_lms.copy()
ecptp_out[left_brow_inds] = pdm_clm_lms[left_brow_inds]
ecptp_out[right_brow_inds] = pdm_clm_lms[right_brow_inds]
ecptp_out[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
# get landmarks for ECpTp_jaw (tune jaw)
ecptp_jaw = p_pdm_lms.copy()
ecptp_jaw[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
if map_to_input_size:
ecptp_jaw = test_image_transform.apply(ecptp_jaw)
ecptp_out = test_image_transform.apply(ecptp_out)
ect_lms = test_image_transform.apply(ect_lms)
init_lms = test_image_transform.apply(init_lms)
p_pdm_lms = test_image_transform.apply(p_pdm_lms)
pdm_clm_lms = test_image_transform.apply(pdm_clm_lms)
ecptp_jaw_list.append(ecptp_jaw) # E + p-correction + p-tuning (ECpTp_jaw)
ecptp_out_list.append(ecptp_out) # E + p-correction + p-tuning (ECpTp_out)
ect_list.append(ect_lms) # ECT prediction
e_list.append(init_lms) # init prediction from heatmap network (E)
ecp_list.append(p_pdm_lms) # init prediction + part pdm correction (ECp)
ecpt_list.append(pdm_clm_lms) # init prediction + part pdm correction + global tuning (ECpT)
pred_dict = {
'E': e_list,
'ECp': ecp_list,
'ECpT': ecpt_list,
'ECT': ect_list,
'ECpTp_jaw': ecptp_jaw_list,
'ECpTp_out': ecptp_out_list
}
return pred_dict
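

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The paths below are
# hypothetical placeholders: real image data, bounding boxes and a trained
# checkpoint must exist at these locations for this to run. The call pattern
# follows the TEST-mode API defined above: get_image_maps() builds the network,
# restores the checkpoint and returns the primary/fusion/upsampled heatmaps.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = DeepHeatmapsModel(
        mode='TEST',
        img_path='data',                              # hypothetical data directory
        test_data='full',
        test_model_path='model/deep_heatmaps-50000',  # hypothetical checkpoint
        menpo_verbose=False)
    # predict heatmaps for the first test image (a menpo image object)
    map_primary, map_fusion, map_upsample = model.get_image_maps(model.img_menpo_list[0])
    print('upsampled heatmap tensor shape:', map_upsample.shape)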