import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import regularizers

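# Resolve the Colab TPU address; COLAB_TPU_ADDR is set by Colab only when a TPU runtime is attached.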
assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU?'
if 'COLAB_TPU_ADDR' in os.environ:
    TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
else:
    TF_MASTER = ''
tpu_address = TF_MASTER

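# Connect to the TPU cluster, initialize the TPU system, and create a TPUStrategy for distributed training.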
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)


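# CNN for 28x28x1 MNIST digits: two convolutional blocks followed by two L2-regularized dense layers.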
def create_model():
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])


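# Build the MNIST input pipeline from TFDS: scale pixels to [0, 1], shuffle/repeat the training split, then batch.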
def get_dataset(batch_size, is_training=True):
    split = 'train' if is_training else 'test'
    dataset, info = tfds.load(name='mnist', split=split, with_info=True,
                              as_supervised=True, try_gcs=True)

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255.0
        return image, label

    dataset = dataset.map(scale)
    if is_training:
        dataset = dataset.shuffle(10000)
        dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return dataset


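# Build and compile the model inside the strategy scope so its variables are placed on the TPU.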
with strategy.scope():
    model = create_model()
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    model.summary()


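# Prepare the datasets, rebuild and compile the model, then train; steps_per_execution=50 runs 50 batches per TPU call to cut host round trips.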
batch_size = 512
train_dataset = get_dataset(batch_size, True)
validation_dataset = get_dataset(batch_size, False)

with strategy.scope():
    model = create_model()
    model.compile(optimizer='adam',
                  steps_per_execution=50,
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])

epochs = 80
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

history = model.fit(train_dataset,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=validation_dataset,
                    validation_steps=validation_steps)


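# Plot training/validation accuracy and loss curves from the History object.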
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(15, 15))

plt.subplot(2, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


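# Predict on a batch of validation images and display the first ten results, coloring each title by correctness.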
final_dataset = validation_dataset.take(10)
test_images, test_labels = next(iter(final_dataset))
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

predictions = model.predict(test_images)

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6),
                         subplot_kw={'xticks': [], 'yticks': []})
for i, ax in enumerate(axes.flat):
    # imshow expects a (28, 28) array, so drop the trailing channel dimension.
    ax.imshow(tf.squeeze(test_images[i]), cmap='gray')

    true_label = class_names[test_labels[i]]
    pred_label = class_names[np.argmax(predictions[i])]
    if true_label == pred_label:
        ax.set_title("True: {}, AI: {}".format(true_label, pred_label), color='green')
    else:
        ax.set_title("True: {}, AI: {}".format(true_label, pred_label), color='red')

plt.tight_layout()
plt.show()