File size: 861 Bytes
4ac4e3b
fafff42
 
4ac4e3b
fafff42
 
 
 
 
 
 
 
 
 
 
 
4ac4e3b
fafff42
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import tensorflow as tf
from tensorflow import keras


class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalize square matrices for deviating from orthogonality.

    The input is viewed as a stack of ``num_features x num_features``
    matrices ``A``; the returned penalty is
    ``sum over the stack of l2reg * sum((A @ A^T - I) ** 2)``.

    Reference: https://keras.io/examples/vision/pointnet/#build-a-model
    """

    def __init__(self, num_features, l2reg=0.001):
        """Create the regularizer.

        Args:
            num_features: Side length F of each square matrix being regularized.
            l2reg: Scale factor applied to the squared deviation from identity.
        """
        self.num_features = num_features
        self.l2reg = l2reg
        # Built once here; cast to the input's dtype on every call.
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for ``x``.

        Args:
            x: Tensor whose total size is a multiple of F*F; it is reshaped to
                ``(-1, F, F)``. Presumably the ``(batch, F*F)`` output of a
                Dense layer as in the referenced PointNet example -- confirm
                against the caller.

        Returns:
            Scalar tensor with the penalty summed over every matrix in ``x``.
        """
        identity = tf.cast(self.identity, x.dtype)
        # -1 infers the leading dimension, so (batch, F*F), (batch, F, F) and a
        # single flat F*F input are all accepted.
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        # Batched A @ A^T: each matrix is multiplied only with its own
        # transpose. (tf.tensordot over axis 2 would instead pair every batch
        # entry with every other one, producing a (B, F, B, F) tensor.)
        xxt = tf.matmul(x, x, transpose_b=True)
        # identity (F, F) broadcasts against xxt (B, F, F).
        return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this object."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}