from typing import Optional

import matplotlib.pyplot as plt
import tensorflow as tf


def visualize_data(
    dataset: tf.data.Dataset,
    data_augmentation: Optional[tf.keras.Sequential] = None,
) -> None:
    """Plot a 3x3 grid of images taken from the first batch of ``dataset``.

    Without ``data_augmentation``, the grid shows up to the first nine images
    of the first batch, each titled with its own label. With
    ``data_augmentation``, the grid shows nine separately augmented versions
    of the *first* image in the batch, each titled with that image's label.

    Args:
        dataset: Dataset yielding ``(images, labels)`` batches. Assumes images
            are batched (leading batch axis) with pixel values in ``[0, 255]``
            and that each label converts with ``int()`` (i.e. a scalar class
            id, not one-hot) -- TODO confirm against callers.
        data_augmentation: Optional model applied to the first image before
            it is plotted. It is called once per panel, so each panel can show
            a different augmentation.
    """
    plt.figure(figsize=(10, 10))
    for images, labels in dataset.take(1):
        if data_augmentation is not None:
            # Only image 0 is ever displayed, so augment just that one (kept
            # as a batch of size 1) rather than the whole batch nine times.
            first_image = images[:1]
            first_label = int(labels[0])
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                augmented = data_augmentation(first_image)
                plt.imshow(augmented[0].numpy().astype('uint8'))
                # Every panel shows image 0, so every panel carries its label.
                plt.title(first_label)
                plt.axis('off')
        else:
            # Clamp to the batch size so batches with fewer than 9 images work.
            for i in range(min(9, len(images))):
                plt.subplot(3, 3, i + 1)
                plt.imshow(images[i].numpy().astype('uint8'))
                plt.title(int(labels[i]))
                plt.axis('off')
    plt.show()