Spaces:
Runtime error
Runtime error
Add Tensorflow Training Script
Browse files- mnist-tf.py +92 -0
mnist-tf.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Download (or load from the local Keras cache) the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Cast to float32, divide by 255.0, and flatten every image into a
# 784-element row vector, one split at a time.
x_train = (x_train.astype(np.float32) / 255.0).reshape(-1, 784)
x_test = (x_test.astype(np.float32) / 255.0).reshape(-1, 784)

# Turn the integer class labels into 10-wide one-hot vectors.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# Wrap both splits in tf.data pipelines yielding batches of 100 examples.
# Only the training split is shuffled (buffer of 60000 elements).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(60000).batch(100)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(100)
|
20 |
+
|
21 |
+
# Layer widths of the fully connected network: 784 inputs -> 256 -> 256 -> 10.
input_layer_size = 784
hidden_layer_one = 256
hidden_layer_two = 256
number_classes = 10

# Trainable weight matrices, keyed by layer. Each is drawn from
# tf.random.normal with its default mean and standard deviation.
weights = {
    # input -> first hidden layer
    'w1': tf.Variable(
        tf.random.normal(shape=[input_layer_size, hidden_layer_one], dtype=tf.float32)
    ),
    # first hidden -> second hidden layer
    'w2': tf.Variable(
        tf.random.normal(shape=[hidden_layer_one, hidden_layer_two], dtype=tf.float32)
    ),
    # second hidden -> output layer
    'w_out': tf.Variable(
        tf.random.normal(shape=[hidden_layer_two, number_classes], dtype=tf.float32)
    ),
}

# Trainable bias vectors, one per layer, initialised the same way.
biases = {
    'b1': tf.Variable(tf.random.normal(shape=[hidden_layer_one], dtype=tf.float32)),
    'b2': tf.Variable(tf.random.normal(shape=[hidden_layer_two], dtype=tf.float32)),
    'b_out': tf.Variable(tf.random.normal(shape=[number_classes], dtype=tf.float32)),
}
|
38 |
+
|
39 |
+
# Network architecture
|
40 |
+
def feedforward_network(x):
    """Run the two-hidden-layer network on `x` and return the raw logits.

    Reads the module-level `weights` and `biases` dictionaries. Both hidden
    layers apply ReLU; the output layer is left linear (no softmax), so the
    result is suitable for `softmax_cross_entropy_with_logits`.

    Args:
        x: Batch of inputs that can be matrix-multiplied with
            `weights['w1']`.

    Returns:
        Tensor of un-normalised class scores, one row per input row.
    """
    # First hidden layer: affine transform followed by ReLU.
    pre_activation_1 = tf.add(tf.matmul(x, weights['w1']), biases['b1'])
    layer_1 = tf.nn.relu(pre_activation_1)

    # Second hidden layer: same pattern on top of layer_1.
    pre_activation_2 = tf.add(tf.matmul(layer_1, weights['w2']), biases['b2'])
    layer_2 = tf.nn.relu(pre_activation_2)

    # Linear read-out producing one score per class.
    return tf.matmul(layer_2, weights['w_out']) + biases['b_out']
|
45 |
+
|
46 |
+
# Training hyperparameters
epochs = 45
learning_rate = 0.001
job_dir = 'mnist_model'

# Build the optimizer and the list of trainable variables exactly once.
# Adam keeps running moment estimates and a step counter per variable;
# constructing a fresh optimizer for every batch would throw that state
# away each step (and allocate new slot variables every time).
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
trainable_variables = list(weights.values()) + list(biases.values())

# Training loop
for epoch in range(epochs):
    for batch_x, batch_y in train_dataset:
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            logits = feedforward_network(batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=batch_y)
            )

        # Back-propagate and take one Adam step on every weight and bias.
        gradients = tape.gradient(loss, trainable_variables)
        optimizer.apply_gradients(zip(gradients, trainable_variables))

    # Print loss every epoch (this is the loss of the epoch's final batch).
    print(f"Epoch {epoch+1}, Loss: {loss.numpy()}")
|
64 |
+
|
65 |
+
# Evaluation
|
66 |
+
def evaluate(dataset):
    """Return the classification accuracy of the network over `dataset`.

    Args:
        dataset: Iterable of `(batch_x, batch_y)` pairs, where `batch_y`
            holds one score/one-hot row per example so that `argmax` along
            axis 1 gives the true class.

    Returns:
        Number of correctly classified examples divided by the number of
        examples seen. An empty dataset leads to a division by zero.
    """
    correct_predictions = 0
    total_predictions = 0
    for batch_x, batch_y in dataset:
        logits = feedforward_network(batch_x)

        # Compare the highest-scoring class with the labelled class.
        predicted_classes = tf.argmax(logits, 1)
        true_classes = tf.argmax(batch_y, 1)
        matches = tf.cast(tf.equal(predicted_classes, true_classes), tf.int32)

        correct_predictions += tf.reduce_sum(matches).numpy()
        total_predictions += batch_x.shape[0]
    return correct_predictions / total_predictions
|
74 |
+
|
75 |
+
# Score the trained network on the held-out test split and report it.
accuracy = evaluate(test_dataset)
print(f"Test accuracy: {accuracy}")
|
77 |
+
|
78 |
+
@tf.function(
    input_signature=[tf.TensorSpec(shape=[None, 784], dtype=tf.float32)]
)
def serve_model(x):
    """Serving entry point: map a float32 `[batch, 784]` tensor to logits.

    The fixed input signature lets the function be traced once for any
    batch size. The result is wrapped in a dict under the key 'output'.
    """
    logits = feedforward_network(x)
    return {'output': logits}
|
81 |
+
|
82 |
+
# Save the model
|
83 |
+
class MyModel(tf.Module):
    """Container module that bundles the variables with the serving function.

    Holding the weight and bias dictionaries as attributes makes the
    variables reachable from this object when it is passed to
    `tf.saved_model.save`, alongside the module-level `serve_model` function.
    """

    def __init__(self, weights, biases):
        """Store the variable dictionaries and attach `serve_model`.

        Args:
            weights: Dict of weight variables.
            biases: Dict of bias variables.
        """
        super().__init__()
        self.weights = weights
        self.biases = biases
        # Module-level tf.function, attached so it is exported with the model.
        self.serve_model = serve_model
|
89 |
+
|
90 |
+
# Wrap the trained variables in a trackable module and export it in
# SavedModel format under <job_dir>/model.
model = MyModel(weights, biases)
save_path = os.path.join(job_dir, 'model')
tf.saved_model.save(model, save_path)
|