MiChaelinzo
commited on
Commit
·
783ba85
1
Parent(s):
c19c42e
Create cifar10_cnn_keras_tensor.py
Browse files- cifar10_cnn_keras_tensor.py +61 -0
cifar10_cnn_keras_tensor.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import datasets, layers, models
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
|
5 |
+
# Load CIFAR-10 dataset
|
6 |
+
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
|
7 |
+
|
8 |
+
# Normalize pixel values to be between 0 and 1
|
9 |
+
train_images, test_images = train_images / 255.0, test_images / 255.0
|
10 |
+
|
11 |
+
# Define the CNN model
|
12 |
+
model = models.Sequential()
|
13 |
+
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
|
14 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
15 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
|
16 |
+
model.add(layers.MaxPooling2D((2, 2)))
|
17 |
+
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
|
18 |
+
model.add(layers.Flatten())
|
19 |
+
model.add(layers.Dense(64, activation='relu'))
|
20 |
+
model.add(layers.Dense(10))
|
21 |
+
|
22 |
+
# Compile the model
|
23 |
+
model.compile(optimizer='adam',
|
24 |
+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
25 |
+
metrics=['accuracy'])
|
26 |
+
|
27 |
+
# Define a callback to stop training when desired accuracy is achieved
|
28 |
+
class AccuracyCallback(tf.keras.callbacks.Callback):
|
29 |
+
def on_epoch_end(self, epoch, logs={}):
|
30 |
+
if logs.get('val_accuracy') > 0.90:
|
31 |
+
print("\nReached 90% accuracy, stopping training...")
|
32 |
+
self.model.stop_training = True
|
33 |
+
accuracy_callback = AccuracyCallback()
|
34 |
+
|
35 |
+
# Train the model
|
36 |
+
history = model.fit(train_images, train_labels, epochs=50,
|
37 |
+
validation_data=(test_images, test_labels),
|
38 |
+
callbacks=[accuracy_callback])
|
39 |
+
|
40 |
+
# Evaluate the model
|
41 |
+
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
|
42 |
+
print('Test accuracy:', test_acc)
|
43 |
+
|
44 |
+
# Plot accuracy and loss curves
|
45 |
+
acc = history.history['accuracy']
|
46 |
+
val_acc = history.history['val_accuracy']
|
47 |
+
loss = history.history['loss']
|
48 |
+
val_loss = history.history['val_loss']
|
49 |
+
epochs = range(len(acc))
|
50 |
+
plt.figure(figsize=(10, 5))
|
51 |
+
plt.subplot(1, 2, 1)
|
52 |
+
plt.plot(epochs, acc, 'r', label='Training accuracy')
|
53 |
+
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
|
54 |
+
plt.title('Training and validation accuracy')
|
55 |
+
plt.legend()
|
56 |
+
plt.subplot(1, 2, 2)
|
57 |
+
plt.plot(epochs, loss, 'r', label='Training loss')
|
58 |
+
plt.plot(epochs, val_loss, 'b', label='Validation loss')
|
59 |
+
plt.title('Training and validation loss')
|
60 |
+
plt.legend()
|
61 |
+
plt.show()
|