load function
load_bnn_model.py  ADDED  +100 -0
@@ -0,0 +1,100 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp


def load_bnn_model():
    FEATURE_NAMES = [
        "fixed acidity",
        "volatile acidity",
        "citric acid",
        "residual sugar",
        "chlorides",
        "free sulfur dioxide",
        "total sulfur dioxide",
        "density",
        "pH",
        "sulphates",
        "alcohol",
    ]

    hidden_units = [8, 8]
    learning_rate = 0.001
    def create_model_inputs():
        inputs = {}
        for feature_name in FEATURE_NAMES:
            inputs[feature_name] = layers.Input(
                name=feature_name, shape=(1,), dtype=tf.float32
            )
        return inputs

    # Define the prior weight distribution as Normal with mean=0 and stddev=1.
    # Note that, in this example, the prior distribution is not trainable:
    # its parameters are fixed.
    def prior(kernel_size, bias_size, dtype=None):
        n = kernel_size + bias_size
        prior_model = keras.Sequential(
            [
                tfp.layers.DistributionLambda(
                    lambda t: tfp.distributions.MultivariateNormalDiag(
                        loc=tf.zeros(n), scale_diag=tf.ones(n)
                    )
                )
            ]
        )
        return prior_model

    # Define the variational posterior weight distribution as a multivariate
    # Gaussian. Note that the learnable parameters for this distribution are
    # the means, variances, and covariances.
    def posterior(kernel_size, bias_size, dtype=None):
        n = kernel_size + bias_size
        posterior_model = keras.Sequential(
            [
                tfp.layers.VariableLayer(
                    tfp.layers.MultivariateNormalTriL.params_size(n), dtype=dtype
                ),
                tfp.layers.MultivariateNormalTriL(n),
            ]
        )
        return posterior_model
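
    # How these are used: tfp.layers.DenseVariational (below) calls `prior` and
    # `posterior` to build distributions over its kernel and bias, and adds the
    # KL divergence between posterior and prior, scaled by `kl_weight`, to the
    # model loss. This KL term is what regularizes the weight distributions.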

    def create_probabilistic_bnn_model(train_size):
        inputs = create_model_inputs()
        features = keras.layers.concatenate(list(inputs.values()))
        features = layers.BatchNormalization()(features)

        # Create hidden layers with weight uncertainty using the DenseVariational layer.
        for units in hidden_units:
            features = tfp.layers.DenseVariational(
                units=units,
                make_prior_fn=prior,
                make_posterior_fn=posterior,
                kl_weight=1 / train_size,
                activation="sigmoid",
            )(features)

        # Create a probabilistic output (Normal distribution), and use the `Dense` layer
        # to produce the parameters of the distribution.
        # We set units=2 to learn both the mean and the variance of the Normal distribution.
        distribution_params = layers.Dense(units=2)(features)
        outputs = tfp.layers.IndependentNormal(1)(distribution_params)

        model = keras.Model(inputs=inputs, outputs=outputs)

        return model
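
    # Note: because the output layer is `tfp.layers.IndependentNormal`, calling
    # the model returns a distribution object rather than a plain tensor, so a
    # point prediction and its aleatoric uncertainty are available via .mean()
    # and .stddev(), while repeated calls resample the variational weights.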

    def negative_loglikelihood(targets, estimated_distribution):
        # The model's output is already a TFP distribution, so we can
        # evaluate log_prob on it directly.
        return -estimated_distribution.log_prob(targets)

    model = create_probabilistic_bnn_model(4163)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=learning_rate),
        loss=negative_loglikelihood,
        metrics=[keras.metrics.RootMeanSquaredError()],
    )
    model.load_weights('bnn_wine_model.h5')
    return model
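
Usage sketch (not part of the committed file): a minimal example of how the returned model might be queried, assuming the weights file bnn_wine_model.h5 is present in the working directory. The feature values below are made up for illustration.

import numpy as np

from load_bnn_model import load_bnn_model

model = load_bnn_model()

# A single wine sample; the values are placeholders, not real data.
example = {
    "fixed acidity": np.array([[7.4]], dtype=np.float32),
    "volatile acidity": np.array([[0.7]], dtype=np.float32),
    "citric acid": np.array([[0.0]], dtype=np.float32),
    "residual sugar": np.array([[1.9]], dtype=np.float32),
    "chlorides": np.array([[0.076]], dtype=np.float32),
    "free sulfur dioxide": np.array([[11.0]], dtype=np.float32),
    "total sulfur dioxide": np.array([[34.0]], dtype=np.float32),
    "density": np.array([[0.9978]], dtype=np.float32),
    "pH": np.array([[3.51]], dtype=np.float32),
    "sulphates": np.array([[0.56]], dtype=np.float32),
    "alcohol": np.array([[9.4]], dtype=np.float32),
}

# Each call samples fresh weights from the variational posterior, so the
# spread of the sampled means reflects epistemic uncertainty, while each
# distribution's stddev reflects aleatoric uncertainty.
dists = [model(example) for _ in range(10)]
means = np.stack([d.mean().numpy().squeeze() for d in dists])
stds = np.stack([d.stddev().numpy().squeeze() for d in dists])
print(f"predicted quality: {means.mean():.2f} +/- {means.std():.2f} (epistemic)")
print(f"average aleatoric stddev: {stds.mean():.2f}")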