# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Closed form linear regression.

Can be differentiated through.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
import sonnet as snt
import tensorflow as tf

from learning_unsupervised_learning import utils
from learning_unsupervised_learning import variable_replace

def solve_ridge(x, y, ridge_factor):
  with tf.name_scope("solve_ridge"):
    # Add a column of ones to the end of the feature matrix for the bias term.
    A = tf.concat([x, tf.ones((x.shape.as_list()[0], 1))], axis=1)

    # Analytic solution to the ridge regression loss.
    inv_target = tf.matmul(A, A, transpose_a=True)
    np_diag_penalty = ridge_factor * np.ones(
        A.shape.as_list()[1], dtype="float32")
    # Remove the penalty on the bias component of the weights.
    np_diag_penalty[-1] = 0.
    diag_penalty = tf.constant(np_diag_penalty)
    inv_target += tf.diag(diag_penalty)

    inv = tf.matrix_inverse(inv_target)
    w = tf.matmul(inv, tf.matmul(A, y, transpose_a=True))
    return w
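

# The graph above computes the standard closed-form ridge solution
#   w = (A^T A + ridge_factor * D)^{-1} A^T y,
# where A is the feature matrix with a column of ones appended and D is the
# identity with its last (bias) entry zeroed. The helper below is an
# illustrative NumPy sketch of the same computation; it is not part of the
# original module and is only meant as a small reference for sanity-checking
# the TF op chain above.
def _solve_ridge_numpy_reference(x, y, ridge_factor):
  """NumPy sketch of solve_ridge (illustration only)."""
  a = np.concatenate([x, np.ones((x.shape[0], 1), dtype=x.dtype)], axis=1)
  diag_penalty = ridge_factor * np.ones(a.shape[1], dtype=a.dtype)
  diag_penalty[-1] = 0.  # No penalty on the bias column.
  normal_eq = a.T.dot(a) + np.diag(diag_penalty)
  # Solving the linear system gives the same result as explicitly inverting
  # normal_eq (as the TF version does), but is numerically more stable.
  return np.linalg.solve(normal_eq, a.T.dot(y))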


class LinearRegressionMetaObjective(snt.AbstractModule):
  """A meta-objective that trains ridge regression with an analytic solution.

  This is used to evaluate the performance of a given feature set trained in
  some other manner.
  """

  def __init__(self,
               local_device=None,
               remote_device=None,
               zero_one_labels=True,
               normalize_y_hat=True,
               normalize_act=False,
               averages=1,
               ridge_factor=0.1,
               center_y=True,
               hinge_loss=False,
               samples_per_class=10,
               test_train_scalar=1.0,
              ):
    self._local_device = local_device
    self._remote_device = remote_device
    self.zero_one_labels = zero_one_labels
    self.normalize_y_hat = normalize_y_hat
    self.normalize_act = normalize_act
    self.ridge_factor = ridge_factor
    self.averages = averages
    self.samples_per_class = samples_per_class
    self.center_y = center_y
    self.test_train_scalar = test_train_scalar
    self.hinge_loss = hinge_loss

    self.dataset_map = {}

    super(LinearRegressionMetaObjective,
          self).__init__(name="LinearRegressionMetaObjective")

  def _build(self, dataset, feature_transformer):
    if self.samples_per_class is not None:
      if dataset not in self.dataset_map:
        # Dataset ops must be created outside of any while-loop frames.
        with tf.control_dependencies(None):
          self.dataset_map[dataset] = utils.sample_n_per_class(
              dataset, self.samples_per_class)
      dataset = self.dataset_map[dataset]

    stats = collections.defaultdict(list)
    losses = []
    # TODO(lmetz) move this to ingraph control flow?
    for _ in range(self.averages):
      loss, stat = self._build_once(dataset, feature_transformer)
      losses.append(loss)
      for k, v in stat.items():
        stats[k].append(v)
    stats = {k: tf.add_n(v) / float(len(v)) for k, v in stats.items()}

    summary_updates = []
    for k, v in stats.items():
      tf.summary.scalar(k, v)

    with tf.control_dependencies(summary_updates):
      return tf.add_n(losses) / float(len(losses))

  def _build_once(self, dataset, feature_transformer):
    with tf.device(self._local_device):
      batch = dataset()
      num_classes = batch.label_onehot.shape.as_list()[1]
      regression_mod = snt.Linear(num_classes)
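
      # Note: when normalize_act is set, the transformed features are
      # l2-normalized along axis 0 (across the batch) before the regression.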
      if self.normalize_act:

        def normalize_transformer(x):
          unnorm_x = feature_transformer(x)
          return tf.nn.l2_normalize(unnorm_x, 0)

        feature_transformer_wrap = normalize_transformer
      else:
        feature_transformer_wrap = feature_transformer

      # Construct the variables of the right shape in the sonnet module by
      # calling a forward pass through the regressor.
      with utils.assert_no_new_variables():
        dummy_features = feature_transformer_wrap(batch)
        regression_mod(dummy_features)
      reg_w = regression_mod.w
      reg_b = regression_mod.b

      # Grab a new batch of data from the dataset for the test set.
      batch_test = dataset()
      all_batch = utils.structure_map_multi(lambda x: tf.concat(x, 0),
                                            [batch, batch_test])
      # all_batch = tf.concat([batch, batch_test], 0)

      features = feature_transformer_wrap(all_batch)
      features, features_test = utils.structure_map_split(
          lambda x: tf.split(x, 2, axis=0), features)
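
      # center_y below subtracts the mean of the targets and rescales them so
      # the average squared row norm is 1, putting the targets on a comparable
      # scale regardless of the label encoding.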
      def center_y(y):
        y -= tf.reduce_mean(y)
        y *= tf.rsqrt(
            tf.reduce_mean(tf.reduce_sum(y**2, axis=[1], keep_dims=True)))
        return y

      def get_y_vec(batch):
        y_pieces = []
        if hasattr(batch, "label_onehot"):
          if self.zero_one_labels:
            y_pieces += [batch.label_onehot]
          else:
            y_pieces += [2. * batch.label_onehot - 1.]
        if hasattr(batch, "regression_target"):
          y_pieces += [batch.regression_target]
        y = tf.concat(y_pieces, 1)
        if self.center_y:
          y = center_y(y)
        return y

      y_train = get_y_vec(batch)

      w = solve_ridge(features, y_train, self.ridge_factor)
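      # Because a column of ones was appended as the last column of A inside
      # solve_ridge, w[:-1] holds the feature weights and w[-1] the bias.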

      # Compute targets for the held-out batch so the loss can be evaluated on
      # a validation set. This provides a less overfit signal to the learned
      # optimizer.
      y_test = get_y_vec(batch_test)

      def compute_logit(features):
        # The regression module still holds the dummy variables created above;
        # substitute in the solved weights and bias to compute the logits.
        replacement = collections.OrderedDict([(reg_w, w[:-1]),
                                               (reg_b, w[-1])])
        with variable_replace.variable_replace(replacement):
          logits = regression_mod(features)
        return logits

      batch_size = y_train.shape.as_list()[0]

      logit_train = compute_logit(features)
      logit_test_unnorm = compute_logit(features_test)
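
      # When normalize_y_hat is set, each test prediction is rescaled to unit
      # l2 norm, so the test losses computed below are insensitive to the
      # overall magnitude of the predictions.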
      if self.normalize_y_hat:
        logit_test = logit_test_unnorm / tf.sqrt(
            tf.reduce_sum(logit_test_unnorm**2, axis=[1], keep_dims=True))
      else:
        logit_test = logit_test_unnorm

      stats = {}

      if self.hinge_loss:
        # Slightly closer to the true classification loss: any distance
        # smaller than 1 is guaranteed to map to the correct class.
        mse_test = tf.reduce_sum(
            tf.nn.relu(
                tf.reduce_sum(tf.square(logit_test - y_test), axis=1) -
                1.)) / batch_size
      else:
        mse_test = tf.reduce_sum(tf.square(logit_test - y_test)) / batch_size
      stats["mse_test"] = mse_test

      mse_train = tf.reduce_sum(tf.square(logit_train - y_train)) / batch_size
      stats["mse_train"] = mse_train

      is_correct_test = tf.equal(
          tf.argmax(logit_test, 1), tf.argmax(y_test, 1))
      accuracy_test = tf.reduce_mean(tf.cast(is_correct_test, tf.float32))
      stats["accuracy_test"] = accuracy_test

      def test_confusion_fn():
        test_confusion = tf.confusion_matrix(
            tf.argmax(y_test, 1), tf.argmax(logit_test, 1))
        test_confusion = tf.to_float(test_confusion) / tf.constant(
            logit_test.shape.as_list()[0] /
            float(logit_test.shape.as_list()[1]),
            dtype=tf.float32)
        test_confusion = tf.expand_dims(tf.expand_dims(test_confusion, 0), 3)
        return test_confusion

      tf.summary.image("test_confusion", test_confusion_fn())

      def train_confusion_fn():
        train_confusion = tf.confusion_matrix(
            tf.argmax(y_train, 1), tf.argmax(logit_train, 1))
        train_confusion = tf.to_float(train_confusion) / tf.constant(
            logit_train.shape.as_list()[0] /
            float(logit_train.shape.as_list()[1]),
            dtype=tf.float32)
        train_confusion = tf.expand_dims(tf.expand_dims(train_confusion, 0), 3)
        return train_confusion

      tf.summary.image("train_confusion", train_confusion_fn())

      is_correct = tf.equal(tf.argmax(logit_train, 1), tf.argmax(y_train, 1))
      accuracy_train = tf.reduce_mean(tf.cast(is_correct, tf.float32))
      stats["accuracy_train"] = accuracy_train

      reg = self.ridge_factor * tf.reduce_sum(tf.square(w[:-1])) / batch_size
      stats["ridge_component"] = reg

      stats["total_loss"] = mse_test + reg

      loss_to_train_at = (reg + mse_test) * self.test_train_scalar + (
          mse_train + reg) * (1 - self.test_train_scalar)

      loss_to_train_at = tf.identity(loss_to_train_at)

      # Minimizing the test loss should not require regularization because the
      # meta-objective is solved for the training loss.
      return loss_to_train_at, stats

  def local_variables(self):
    """List of variables that need to be updated for each evaluation.

    These variables should not be stored on a parameter server and
    should be reset every computation of a meta_objective loss.

    Returns:
      vars: list of tf.Variable
    """
    return list(
        snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))

  def remote_variables(self):
    return []