NCTC / models /research /ptn /model_rotator.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for pretraining (rotator) as described in PTN paper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from six.moves import xrange
import tensorflow as tf
import input_generator
import losses
import metrics
import utils
from nets import deeprotator_factory
# Short alias for TF-Slim. This module uses it for scalar/image summaries,
# checkpoint restoring (assign_from_checkpoint) and train-op creation.
# NOTE: tf.contrib only exists in TensorFlow 1.x.
slim = tf.contrib.slim
def _get_data_from_provider(inputs, batch_size, split_name):
  """Batches a single-example input stream into fixed-size minibatches.

  Args:
    inputs: dict from the dataset provider; its 'image', 'mask' and
      'num_samples' entries are read.
    batch_size: number of examples per batch.
    split_name: dataset split name, used to name the batching queue.

  Returns:
    dict with 'images' and 'masks' (the batched tensors) and 'num_samples'
    (forwarded unchanged from `inputs`).
  """
  batched_images, batched_masks = tf.train.batch(
      [inputs['image'], inputs['mask']],
      batch_size=batch_size,
      num_threads=8,
      # Queue capacity is expressed in examples: room for eight batches.
      capacity=8 * batch_size,
      name='batching_queues/%s' % (split_name))
  return {
      'images': batched_images,
      'masks': batched_masks,
      'num_samples': inputs['num_samples'],
  }
def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
               is_training):
  """Builds the batched input pipeline for one dataset split.

  Args:
    dataset_dir: directory holding the dataset, passed to
      `input_generator.get`.
    dataset_name: dataset name, passed to `input_generator.get`.
    split_name: which split to read.
    batch_size: number of examples per batch.
    image_size: ignored; kept so the signature matches other model modules.
    is_training: when true, examples are read in shuffled order.

  Returns:
    The dict built by `_get_data_from_provider` ('images', 'masks',
    'num_samples').
  """
  del image_size  # Unused
  scope_name = 'data_loading_%s/%s' % (dataset_name, split_name)
  with tf.variable_scope(scope_name):
    # Reader settings: 4 parallel readers feeding a common queue that holds
    # between 50 and 256 examples.
    provider_inputs = input_generator.get(
        dataset_dir,
        dataset_name,
        split_name,
        shuffle=is_training,
        num_readers=4,
        common_queue_min=50,
        common_queue_capacity=256)
    return _get_data_from_provider(provider_inputs, batch_size, split_name)
def preprocess(raw_inputs, step_size):
  """Selects, per example, a start view and a rotation sequence to train on.

  For each example, a start viewpoint is drawn uniformly at random, together
  with a rotation action. The action moves the viewpoint by -1 or +1 per
  step; when `step_size == 1` it may also be "stay" (0). The view index wraps
  around at both ends of the view axis.

  NOTE(review): the draws are made with `np.random` while this function
  runs, and their results become static slice indices and a `tf.constant`.
  They are therefore fixed for the lifetime of the built graph rather than
  re-drawn per batch. Confirm this is intended.

  Args:
    raw_inputs: dict with 'images' and 'masks' tensors indexed as
      [example, view, ...]. The static batch size and number of views are
      read from 'images'.
    step_size: number of rotation steps to unroll.

  Returns:
    dict with 'images_k' and 'masks_k' for k in 0..step_size (each stacked
    along a new leading axis), plus 'actions', a float32 [batch, 3] one-hot
    constant.
  """
  dims = raw_inputs['images'].get_shape().as_list()
  batch, view_count, _ = dims[0], dims[1], dims[2]  # dims[2] is unused

  # action id -> (view offset applied per step, one-hot column in actions)
  action_table = {0: (-1, 2), 1: (1, 0), 2: (0, 1)}

  actions = np.zeros((batch, 3), dtype=np.float32)
  image_seqs = [[] for _ in xrange(step_size + 1)]
  mask_seqs = [[] for _ in xrange(step_size + 1)]

  for idx in xrange(batch):
    view = np.random.randint(0, view_count)
    action = np.random.randint(0, 2)
    if step_size == 1:
      # Single-step training may also pick "no rotation". The two-way draw
      # above is still made first.
      action = np.random.randint(0, 3)
    offset, column = action_table[action]
    actions[idx, column] = 1

    for step in xrange(step_size + 1):
      if step:
        # Wrap around the circular view axis.
        view = (view + offset) % view_count
      image_seqs[step].append(raw_inputs['images'][idx, view, :, :, :])
      mask_seqs[step].append(raw_inputs['masks'][idx, view, :, :, :])

  outputs = dict()
  for step in xrange(step_size + 1):
    outputs['images_%d' % step] = tf.stack(image_seqs[step])
    outputs['masks_%d' % step] = tf.stack(mask_seqs[step])
  outputs['actions'] = tf.constant(actions, dtype=tf.float32)
  return outputs
def get_init_fn(scopes, params):
  """Initialization assignment operator function used while training.

  Collects the trainable model variables in `scopes` and prepares an op that
  assigns them their values from the checkpoint at `params.init_model`.

  Args:
    scopes: iterable of scope names whose model variables are restored.
    params: hyperparameters; `params.init_model` is the checkpoint path.

  Returns:
    None when `params.init_model` is falsy. Otherwise, a callable taking a
    session that runs the checkpoint-assignment op with its feed dict.
  """
  if not params.init_model:
    return None

  # Query the trainable-variable collection once instead of once per
  # candidate variable. The membership test itself is unchanged.
  trainable = tf.trainable_variables()
  var_list = []
  for scope in scopes:
    var_list.extend(
        var for var in tf.contrib.framework.get_model_variables(scope)
        if var in trainable)

  init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
      params.init_model, var_list)

  def init_assign_function(sess):
    sess.run(init_assign_op, init_feed_dict)

  return init_assign_function
def get_model_fn(params, is_training, reuse=False):
  """Returns whatever `deeprotator_factory.get` builds for these settings.

  Presumably this is the rotator model function; confirm in
  `nets/deeprotator_factory`.

  Args:
    params: model hyperparameters, passed through unchanged.
    is_training: training-mode flag, passed through unchanged.
    reuse: passed through to the factory unchanged (default False).
  """
  model_fn = deeprotator_factory.get(params, is_training, reuse)
  return model_fn
def get_regularization_loss(scopes, params):
  """Returns the regularization loss computed by `losses.regularization_loss`.

  Args:
    scopes: scopes passed through unchanged.
    params: hyperparameters passed through unchanged.
  """
  reg_loss = losses.regularization_loss(scopes, params)
  return reg_loss
def get_loss(inputs, outputs, params):
  """Computes the rotator loss.

  Starts from a float32 scalar zero. It adds the image term only if `params`
  has `image_weight`, and the mask term only if it has `mask_weight`. The
  total is also recorded as the scalar summary 'rotator_loss' under the
  'losses' prefix.

  Args:
    inputs: rotator input dictionary.
    outputs: rotator output dictionary.
    params: hyperparameters; `step_size` and the optional `image_weight` and
      `mask_weight` are read.

  Returns:
    The summed loss tensor.
  """
  total = tf.zeros(dtype=tf.float32, shape=[])

  if hasattr(params, 'image_weight'):
    image_term = losses.add_rotator_image_loss(
        inputs, outputs, params.step_size, params.image_weight)
    total += image_term

  if hasattr(params, 'mask_weight'):
    mask_term = losses.add_rotator_mask_loss(
        inputs, outputs, params.step_size, params.mask_weight)
    total += mask_term

  slim.summaries.add_scalar_summary(total, 'rotator_loss', prefix='losses')
  return total
def get_train_op_for_scope(loss, optimizer, scopes, params):
  """Train operation function for the given scopes, used for training.

  Only trainable model variables found in `scopes` are optimized, and only
  the UPDATE_OPS registered in those scopes are attached to the train op.

  Args:
    loss: loss tensor to minimize.
    optimizer: optimizer passed to `slim.learning.create_train_op`.
    scopes: iterable of scope names to collect variables and update ops from.
    params: hyperparameters; `params.clip_gradient_norm` is forwarded.

  Returns:
    The train op created by `slim.learning.create_train_op`.
  """
  # Query the trainable-variable collection once instead of once per
  # candidate variable. The membership test itself is unchanged.
  trainable = tf.trainable_variables()
  var_list = []
  update_ops = []
  for scope in scopes:
    var_list.extend(
        var for var in tf.contrib.framework.get_model_variables(scope)
        if var in trainable)
    update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))

  return slim.learning.create_train_op(
      loss,
      optimizer,
      update_ops=update_ops,
      variables_to_train=var_list,
      clip_gradient_norm=params.clip_gradient_norm)
def get_metrics(inputs, outputs, params):
  """Aggregate the metrics for rotator model.

  Merges the image and mask prediction metrics into one pair of dicts and
  adds an 'eval' scalar summary (also printed) for every metric value.

  Args:
    inputs: Input dictionary of the rotator model.
    outputs: Output dictionary returned by the rotator model.
    params: Hyperparameters of the rotator model; `num_views` and
      `image_size` are read.

  Returns:
    names_to_values: metrics->values (dict).
    names_to_updates: metrics->ops (dict).
  """
  names_to_values = dict()
  names_to_updates = dict()

  # The last argument is 3 * image_size**2 for images and image_size**2 for
  # masks. It is presumably the per-view element count (3 channels vs. 1);
  # confirm against `metrics`.
  tmp_values, tmp_updates = metrics.add_image_pred_metrics(
      inputs, outputs, params.num_views, 3 * params.image_size**2)
  names_to_values.update(tmp_values)
  names_to_updates.update(tmp_updates)

  tmp_values, tmp_updates = metrics.add_mask_pred_metrics(
      inputs, outputs, params.num_views, params.image_size**2)
  names_to_values.update(tmp_values)
  names_to_updates.update(tmp_updates)

  # dict.items() works on both Python 2 and 3. dict.iteritems() no longer
  # exists in Python 3.
  for name, value in names_to_values.items():
    slim.summaries.add_scalar_summary(
        value, name, prefix='eval', print_summary=True)

  return names_to_values, names_to_updates
def write_disk_grid(global_step, summary_freq, log_dir, input_images,
                    output_images, pred_images, pred_masks):
  """Adds a grid-image summary and an op that periodically saves the grid.

  Args:
    global_step: step value fed to the Python writer through `tf.py_func`.
    summary_freq: the grid file is written only when
      `global_step % summary_freq == 0`.
    log_dir: directory that receives the '<step>.jpg' files.
    input_images, output_images, pred_images, pred_masks: batched tensors
      tiled into one image by `_build_image_grid`.

  Returns:
    The int64 output of the `tf.py_func` op. Evaluating it runs the writer,
    which skips the write on non-summary steps.
  """
  def write_grid(grid, step):
    """Native python function to call for writing images to files."""
    if global_step_is_off := (step % summary_freq != 0):
      del global_step_is_off
      return 0
    img_path = os.path.join(log_dir, '%s.jpg' % str(step))
    utils.save_image(grid, img_path)
    return 0

  grid_image = _build_image_grid(input_images, output_images, pred_images,
                                 pred_masks)
  slim.summaries.add_image_summary(
      tf.expand_dims(grid_image, axis=0), name='grid_vis')
  return tf.py_func(write_grid, [grid_image, global_step], [tf.int64],
                    'write_grid')[0]
def _build_image_grid(input_images, output_images, pred_images, pred_masks,
                      num_cols=4):
  """Builds a grid image by concatenating the input images.

  Each example contributes a horizontal strip of four tiles: input image,
  output image, predicted image and predicted mask. The mask's last axis is
  tiled 3x so it can sit beside the images. `num_cols` examples are placed
  side by side per grid row, and rows are stacked vertically. The last
  `quantity % num_cols` examples, which do not fill a row, are not drawn.

  Args:
    input_images, output_images, pred_images: batched tensors indexed as
      [example, a, b, c]. They are assumed to be [height, width, channels]
      per example; confirm with the caller. The batch size must be static.
    pred_masks: batched masks with the same layout. They are presumably
      single-channel, since the last axis is tiled 3x; confirm.
    num_cols: examples per grid row (default 4, the previous fixed value).

  Returns:
    The grid tensor. Tiles are joined along axis 1 and rows along axis 0.

  Raises:
    ValueError: if `num_cols` < 1, or the static batch size is unknown or
      smaller than `num_cols`. Previously these cases surfaced as an
      UnboundLocalError or TypeError.
  """
  quantity = input_images.get_shape().as_list()[0]
  if num_cols < 1:
    raise ValueError('num_cols must be >= 1, got %r' % (num_cols,))
  if quantity is None or quantity < num_cols:
    raise ValueError(
        '_build_image_grid needs a static batch size of at least %d, got %r'
        % (num_cols, quantity))

  rows = []
  for row in xrange(quantity // num_cols):
    tiles = []
    for col in xrange(num_cols):
      index = row * num_cols + col
      tiles.extend([
          input_images[index, :, :, :],
          output_images[index, :, :, :],
          pred_images[index, :, :, :],
          tf.tile(pred_masks[index, :, :, :], [1, 1, 3]),
      ])
    # A single concat per row is equivalent to chaining pairwise concats.
    rows.append(tf.concat(tiles, 1))  # side by side (to the right)
  return tf.concat(rows, 0)  # rows stacked downward