brendenc commited on
Commit
af8d508
1 Parent(s): 1430035

Upload load_model.py

Browse files
Files changed (1) hide show
  1. load_model.py +121 -0
load_model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow import keras
3
+ from tensorflow.keras import regularizers
4
+ import numpy as np
5
+ import tensorflow_probability as tfp
6
+
7
+ #Affine Coupling Layer
8
+ ## Creating a custom layer with keras API.
9
+ output_dim = 256
10
+ reg = 0.01
11
+
12
+ def Coupling(input_shape):
13
+ input = keras.layers.Input(shape=input_shape)
14
+
15
+ t_layer_1 = keras.layers.Dense(
16
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
17
+ )(input)
18
+ t_layer_2 = keras.layers.Dense(
19
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
20
+ )(t_layer_1)
21
+ t_layer_3 = keras.layers.Dense(
22
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
23
+ )(t_layer_2)
24
+ t_layer_4 = keras.layers.Dense(
25
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
26
+ )(t_layer_3)
27
+ t_layer_5 = keras.layers.Dense(
28
+ input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
29
+ )(t_layer_4)
30
+
31
+ s_layer_1 = keras.layers.Dense(
32
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
33
+ )(input)
34
+ s_layer_2 = keras.layers.Dense(
35
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
36
+ )(s_layer_1)
37
+ s_layer_3 = keras.layers.Dense(
38
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
39
+ )(s_layer_2)
40
+ s_layer_4 = keras.layers.Dense(
41
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
42
+ )(s_layer_3)
43
+ s_layer_5 = keras.layers.Dense(
44
+ input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
45
+ )(s_layer_4)
46
+
47
+ return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])
48
+
49
+ #Real NVP
50
+ class RealNVP(keras.Model):
51
+ def __init__(self, num_coupling_layers):
52
+ super(RealNVP, self).__init__()
53
+
54
+ self.num_coupling_layers = num_coupling_layers
55
+
56
+ # Distribution of the latent space.
57
+ self.distribution = tfp.distributions.MultivariateNormalDiag(
58
+ loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
59
+ )
60
+ self.masks = np.array(
61
+ [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
62
+ )
63
+ self.loss_tracker = keras.metrics.Mean(name="loss")
64
+ self.layers_list = [Coupling(2) for i in range(num_coupling_layers)]
65
+
66
+ @property
67
+ def metrics(self):
68
+ """List of the model's metrics.
69
+ We make sure the loss tracker is listed as part of `model.metrics`
70
+ so that `fit()` and `evaluate()` are able to `reset()` the loss tracker
71
+ at the start of each epoch and at the start of an `evaluate()` call.
72
+ """
73
+ return [self.loss_tracker]
74
+
75
+ def call(self, x, training=True):
76
+ log_det_inv = 0
77
+ direction = 1
78
+ if training:
79
+ direction = -1
80
+ for i in range(self.num_coupling_layers)[::direction]:
81
+ x_masked = x * self.masks[i]
82
+ reversed_mask = 1 - self.masks[i]
83
+ s, t = self.layers_list[i](x_masked)
84
+ s *= reversed_mask
85
+ t *= reversed_mask
86
+ gate = (direction - 1) / 2
87
+ x = (
88
+ reversed_mask
89
+ * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
90
+ + x_masked
91
+ )
92
+ log_det_inv += gate * tf.reduce_sum(s, [1])
93
+
94
+ return x, log_det_inv
95
+
96
+ # Log likelihood of the normal distribution plus the log determinant of the jacobian.
97
+
98
+ def log_loss(self, x):
99
+ y, logdet = self(x)
100
+ log_likelihood = self.distribution.log_prob(y) + logdet
101
+ return -tf.reduce_mean(log_likelihood)
102
+
103
+ def train_step(self, data):
104
+ with tf.GradientTape() as tape:
105
+
106
+ loss = self.log_loss(data)
107
+
108
+ g = tape.gradient(loss, self.trainable_variables)
109
+ self.optimizer.apply_gradients(zip(g, self.trainable_variables))
110
+ self.loss_tracker.update_state(loss)
111
+
112
+ return {"loss": self.loss_tracker.result()}
113
+
114
+ def test_step(self, data):
115
+ loss = self.log_loss(data)
116
+ self.loss_tracker.update_state(loss)
117
+
118
+ return {"loss": self.loss_tracker.result()}
119
+
120
+ def load_model():
121
+ return RealNVP(num_coupling_layers=6)