import tensorflow as tf
from tensorflow import keras


class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalize deviation of predicted square matrices from orthogonality.

    Adapted from https://keras.io/examples/vision/pointnet/#build-a-model

    The regularized tensor is regrouped into ``num_features x num_features``
    matrices ``A``. The returned penalty is
    ``l2reg * sum((A @ A^T - I) ** 2)``, summed over every matrix.

    Args:
        num_features: Side length of each square matrix. The regularized
            tensor must contain ``num_features ** 2`` values per matrix.
        l2reg: Scalar weight applied to the squared-error penalty.
    """

    def __init__(self, num_features, l2reg=0.001):
        super().__init__()
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for ``x``.

        Args:
            x: Tensor whose elements can be regrouped into
                ``(-1, num_features, num_features)``, e.g. a tensor of shape
                ``(batch, num_features * num_features)``.

        Returns:
            Scalar tensor with the same dtype as ``x``.
        """
        # Cast so the subtraction below also works under mixed precision.
        identity = tf.cast(self.identity, x.dtype)
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        # Batched A @ A^T gives one (n, n) Gram matrix per sample.
        # tf.tensordot(x, x, axes=(2, 2)) is NOT equivalent here: it also
        # pairs every sample with every other sample, giving (B, n, B, n),
        # so reshaping that back to (n, n) blocks mixes rows across samples.
        xxt = tf.matmul(x, x, transpose_b=True)
        return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))

    def get_config(self):
        """Return the constructor arguments used to serialize this regularizer."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}