MiChaelinzo commited on
Commit
783ba85
·
1 Parent(s): c19c42e

Create cifar10_cnn_keras_tensor.py

Browse files
Files changed (1) hide show
  1. cifar10_cnn_keras_tensor.py +61 -0
cifar10_cnn_keras_tensor.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import datasets, layers, models
3
+ import matplotlib.pyplot as plt
4
+
5
+ # Load CIFAR-10 dataset
6
+ (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
7
+
8
+ # Normalize pixel values to be between 0 and 1
9
+ train_images, test_images = train_images / 255.0, test_images / 255.0
10
+
11
+ # Define the CNN model
12
+ model = models.Sequential()
13
+ model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
14
+ model.add(layers.MaxPooling2D((2, 2)))
15
+ model.add(layers.Conv2D(64, (3, 3), activation='relu'))
16
+ model.add(layers.MaxPooling2D((2, 2)))
17
+ model.add(layers.Conv2D(64, (3, 3), activation='relu'))
18
+ model.add(layers.Flatten())
19
+ model.add(layers.Dense(64, activation='relu'))
20
+ model.add(layers.Dense(10))
21
+
22
+ # Compile the model
23
+ model.compile(optimizer='adam',
24
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
25
+ metrics=['accuracy'])
26
+
27
+ # Define a callback to stop training when desired accuracy is achieved
28
+ class AccuracyCallback(tf.keras.callbacks.Callback):
29
+ def on_epoch_end(self, epoch, logs={}):
30
+ if logs.get('val_accuracy') > 0.90:
31
+ print("\nReached 90% accuracy, stopping training...")
32
+ self.model.stop_training = True
33
+ accuracy_callback = AccuracyCallback()
34
+
35
+ # Train the model
36
+ history = model.fit(train_images, train_labels, epochs=50,
37
+ validation_data=(test_images, test_labels),
38
+ callbacks=[accuracy_callback])
39
+
40
+ # Evaluate the model
41
+ test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
42
+ print('Test accuracy:', test_acc)
43
+
44
+ # Plot accuracy and loss curves
45
+ acc = history.history['accuracy']
46
+ val_acc = history.history['val_accuracy']
47
+ loss = history.history['loss']
48
+ val_loss = history.history['val_loss']
49
+ epochs = range(len(acc))
50
+ plt.figure(figsize=(10, 5))
51
+ plt.subplot(1, 2, 1)
52
+ plt.plot(epochs, acc, 'r', label='Training accuracy')
53
+ plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
54
+ plt.title('Training and validation accuracy')
55
+ plt.legend()
56
+ plt.subplot(1, 2, 2)
57
+ plt.plot(epochs, loss, 'r', label='Training loss')
58
+ plt.plot(epochs, val_loss, 'b', label='Validation loss')
59
+ plt.title('Training and validation loss')
60
+ plt.legend()
61
+ plt.show()