# mix-bt / ssl-sota / tf2 / whitening.py
# Uploaded by wgcban ("Upload 98 files", commit 803ef9e); original file size 1.5 kB.
import itertools

from absl import flags
import tensorflow.compat.v2 as tf
# Module-level handle to the absl flag registry; w_mse_loss reads
# num_samples, train_batch_size and proj_out_dim from it.
FLAGS = flags.FLAGS
class Whitening1D(tf.keras.layers.Layer):
    """Decorrelate (whiten) a 2-D batch of features via Cholesky whitening.

    Given ``x`` of shape ``(batch, channels)``, the layer does the following:

    1. Centres every channel on its batch mean.
    2. Estimates the unbiased channel covariance ``S``.
    3. Shrinks it towards the identity: ``S' = (1 - eps) * S + eps * I``.
    4. Factorises ``S' = L L^T`` and returns ``(L^{-1} (x - mean)^T)^T``.

    The output therefore has (up to numerical error) the sample covariance
    ``L^{-1} S L^{-T}``, which is the identity when ``eps == 0``.

    Args:
        eps: Shrinkage weight towards the identity matrix. With the default
            ``0`` the batch needs at least ``channels + 1`` rows. With fewer
            rows ``S`` is rank-deficient and the Cholesky factorisation is
            not well defined.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, eps=0, **kwargs):
        super(Whitening1D, self).__init__(**kwargs)
        self.eps = eps

    def call(self, x):
        """Whiten ``x`` of shape ``(batch, channels)``.

        Returns a tensor with the same shape and dtype as ``x``. A batch of a
        single row divides by zero (``batch - 1``) and yields non-finite
        values.
        """
        input_dtype = x.dtype
        # Cholesky / triangular_solve are not implemented for every float
        # type (e.g. bfloat16), and tf.eye would otherwise default to float32
        # and clash with non-float32 inputs. Do the linear algebra in float32,
        # or float64 if the input already is float64, and cast back at the end.
        compute_dtype = tf.float64 if input_dtype == tf.float64 else tf.float32
        x = tf.cast(x, compute_dtype)

        # Keep the static channel count when known so the output keeps its
        # static shape; otherwise fall back to the dynamic value. The batch
        # size is read dynamically so an unknown batch dimension (None at
        # trace time) still works.
        static_c = x.shape.as_list()[-1]
        c = static_c if static_c is not None else tf.shape(x)[-1]
        bs = tf.cast(tf.shape(x)[0], compute_dtype)
        eye = tf.eye(c, dtype=compute_dtype)

        x_t = tf.transpose(x, (1, 0))  # (channels, batch)
        f = x_t - tf.reduce_mean(x_t, axis=1, keepdims=True)  # centre per channel
        cov = tf.matmul(f, f, transpose_b=True) / (bs - 1.0)  # unbiased estimate
        cov_shrunk = (1 - self.eps) * cov + eye * self.eps
        chol = tf.linalg.cholesky(cov_shrunk)  # cov_shrunk = L @ L^T
        # Solving L @ X = I gives X = L^{-1} (chol is lower-triangular, which
        # is triangular_solve's default assumption).
        inv_chol = tf.linalg.triangular_solve(chol, eye)
        f_hat = tf.matmul(inv_chol, f)
        decorrelated = tf.transpose(f_hat, (1, 0))  # back to (batch, channels)
        return tf.cast(decorrelated, input_dtype)
def w_mse_loss(x):
    """W-MSE loss: whiten projections, then average normalized MSE over view pairs.

    Args:
        x: Projection outputs, shape ``(batch size * num_samples, proj_out_dim)``.
            After whitening, ``x`` is split into ``FLAGS.num_samples`` equal
            row blocks along axis 0. Each block is treated as one view, and
            every unordered pair of views is compared row by row.

    Returns:
        Scalar tensor. For each pair ``(i, j)`` of whitened, L2-normalized
        view blocks it computes ``2 - 2 * mean_over_rows(<z_i, z_j>)``, then
        averages over all ``num_samples * (num_samples - 1) / 2`` pairs.

    Raises:
        ValueError: If ``FLAGS.num_samples < 2`` (there is no pair of views).
    """
    num_samples = FLAGS.num_samples
    if num_samples < 2:
        raise ValueError(
            'w_mse_loss needs FLAGS.num_samples >= 2 to form view pairs, got %d'
            % num_samples)

    whitening = Whitening1D()
    window = 2 * FLAGS.proj_out_dim
    # NOTE(review): the row count is derived from FLAGS.train_batch_size. If
    # this runs per replica while the flag holds the global batch size, the
    # sizes below will not match x -- confirm against the caller.
    total_rows = num_samples * FLAGS.train_batch_size
    # Whiten in slices of roughly `window` rows. When total_rows >= window,
    # every slice has at least 2x the embedding width, which keeps each
    # covariance estimate well conditioned. At least one slice is always used.
    # Any remainder rows are spread one per slice over the leading slices, so
    # tf.split never fails on an uneven division. In the divisible case this
    # is identical to tf.split(x, num_slice, 0).
    num_slice = max(1, total_rows // window)
    base, extra = divmod(total_rows, num_slice)
    slice_sizes = [base + 1] * extra + [base] * (num_slice - extra)
    x = tf.concat([whitening(part) for part in tf.split(x, slice_sizes, 0)], 0)

    x = tf.math.l2_normalize(x, -1)
    views = tf.split(x, num_samples, 0)
    # Same (i, j) order as nested loops with i < j.
    pairs = list(itertools.combinations(views, 2))
    loss = 0
    for z_i, z_j in pairs:
        # For unit vectors, ||a - b||^2 == 2 - 2 * <a, b>.
        loss += 2 - 2 * tf.reduce_mean(tf.reduce_sum(z_i * z_j, -1))
    return loss / len(pairs)